I would like to stochastically cluster the dimensions of some data `X`, then perform operations within each cluster of dimensions before adding each cluster's outputs together.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

X = jnp.array([[1, 2, 3], [4, 5, 6]])  # data matrix; rows are data points, columns are dimensions
probs = jnp.array([0.1, 0.5, 0.4])     # probability of a dimension belonging to each cluster

## Parts of the model definition omitted here
K = len(probs)  # number of clusters available
with numpyro.plate('cluster_assignments', X.shape[1]):
    z = numpyro.sample('z', dist.Categorical(probs))

f_is = jnp.zeros((X.shape[0], K))
for i in range(K):
    # an example within-cluster operation: sum the dimensions in cluster i, then square
    f_is = f_is.at[:, i].set(jnp.sum(X[:, jnp.where(z == i)[0]], axis=1) ** 2)
f = jnp.sum(f_is, axis=1)  # sum the outputs across clusters
### More parts of the model definition omitted here
```

The problem with this code during inference is that it creates a dynamically sized array via the one-argument form of `jnp.where()`. JAX disallows such arrays and raises a `ConcretizationTypeError`. I also cannot use the three-argument form of `jnp.where()`, because the resulting trailing filler zeros (or any other specified fill value) would change the result of the slicing operation on `X`.
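To make the first failure mode concrete, here is a minimal standalone sketch (outside numpyro, with `jax.jit` standing in for the tracing that happens during inference; the function name is just for illustration):

```python
import jax
import jax.numpy as jnp

z = jnp.array([0, 1, 0, 2])

@jax.jit
def indices_of_cluster_zero(z):
    # One-argument jnp.where must return an array whose size depends on
    # the *values* in z, which tracing cannot represent.
    return jnp.where(z == 0)[0]

try:
    indices_of_cluster_zero(z)
except jax.errors.ConcretizationTypeError as e:
    print("caught", type(e).__name__)
```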

In contrast, rewriting the slice with boolean indexing, `X[:, z == i]`, leads to a `TracerBoolConversionError`.
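The boolean-indexing variant hits the same underlying limitation: the output shape depends on the values in the mask. A minimal sketch (again with `jax.jit` standing in for inference-time tracing; depending on the JAX version, the error may surface as a `ConcretizationTypeError` subclass or as `NonConcreteBooleanIndexError`):

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1., 2., 3.], [4., 5., 6.]])

@jax.jit
def cluster_zero_sum(z):
    # Boolean indexing with a traced mask also needs a value-dependent
    # output shape, so it fails under tracing as well.
    return jnp.sum(X[:, z == 0], axis=1)

try:
    cluster_zero_sum(jnp.array([0, 1, 0]))
except (jax.errors.ConcretizationTypeError,
        jax.errors.NonConcreteBooleanIndexError) as e:
    print("caught", type(e).__name__)
```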

Is there some way in numpyro to slice `X` based on `z` (which is sampled and therefore a traced array) that allows for inference? Thanks!
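For concreteness, here is the kind of equivalence I am after: with a *concrete* `z`, a fixed-shape one-hot mask reproduces my loop-and-slice computation above, since multiplying by zeros does not change a sum. This is just my own sketch of the target behavior, not an established numpyro pattern; whether something like it is the recommended route inside a model is exactly my question.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1., 2., 3.], [4., 5., 6.]])
z = jnp.array([0, 1, 1])  # example concrete cluster assignments
K = 3

# Loop/slicing version (only works with a concrete z).
f_loop = sum(jnp.sum(X[:, z == i], axis=1) ** 2 for i in range(K))

# Mask version: one_hot(z, K) has shape (n_dims, K), so X @ mask sums
# each row of X within each cluster in a single fixed-shape matmul.
mask = jax.nn.one_hot(z, K)              # shape (3, K)
f_mask = jnp.sum((X @ mask) ** 2, axis=1)

print(jnp.allclose(f_loop, f_mask))      # the two agree for concrete z
```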