Dynamically sized array problem: slicing an array with indices from another traced array

I want to stochastically cluster the dimensions (columns) of some data X, perform an operation within each cluster of dimensions, and then add the clusters' outputs together.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

X = jnp.array([[1, 2, 3], [4, 5, 6]])  # data matrix; rows are data points, columns are dimensions
probs = jnp.array([0.1, 0.5, 0.4]) # Probability of some dimension being in each cluster

def model(X, probs):  # hypothetical signature; parts of the model definition omitted here
    K = len(probs)  # Number of clusters available
    with numpyro.plate('cluster_assignments', X.shape[1]):
        z = numpyro.sample('z', dist.Categorical(probs))

    f_is = jnp.zeros((X.shape[0], K))
    for i in range(K):
        # an example within-cluster operation; the length of jnp.where(z == i)[0] depends on the values of z
        f_is = f_is.at[:, i].set(jnp.sum(X[:, jnp.where(z == i)[0]], axis=1) ** 2)
    f = jnp.sum(f_is, axis=1)  # sum the outputs across clusters

    # More parts of the model definition omitted here
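To see what the loop computes when nothing is traced, here is a sketch that runs the same per-cluster loop eagerly on the example X. It assumes a hypothetical fixed assignment z = [0, 2, 2] in place of a sampled one; eagerly, jnp.where works because z is a concrete array.

```python
import jax.numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
probs = jnp.array([0.1, 0.5, 0.4])
K = len(probs)

# Hypothetical fixed assignment: column 0 -> cluster 0, columns 1 and 2 -> cluster 2
z = jnp.array([0, 2, 2])

f_is = jnp.zeros((X.shape[0], K))
for i in range(K):
    cols = jnp.where(z == i)[0]  # works eagerly because z is concrete
    f_is = f_is.at[:, i].set(jnp.sum(X[:, cols], axis=1) ** 2)
f = jnp.sum(f_is, axis=1)
print(f_is)  # rows [1, 0, 25] and [16, 0, 121]; cluster 1 is empty, so its column is 0
print(f)     # values 26 and 137
```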

During inference, the problem is that jnp.where() returns an array whose size depends on the values of z. JAX does not allow such dynamically sized arrays and raises a ConcretizationTypeError. I also cannot fix the output size with the size= and fill_value= arguments of jnp.where(), because the padding entries (trailing zeros, or whatever fill value is specified) are themselves valid column indices and would change which columns of X are selected.
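A minimal sketch of both failure modes, assuming the example X and a hypothetical fixed z = [0, 2, 2], with jax.jit standing in for the tracing that happens during inference. The function names cluster_cols and padded_cols are illustrative only.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
z = jnp.array([0, 2, 2])  # hypothetical fixed assignment

@jax.jit
def cluster_cols(z, i):
    # Output length equals the number of columns in cluster i, which depends on the values of z
    return jnp.where(z == i)[0]

try:
    cluster_cols(z, 2)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True
print("jnp.where without size fails under jit:", failed)

@jax.jit
def padded_cols(z, i):
    # size must be a static Python int; unused slots are filled with fill_value
    return jnp.where(z == i, size=z.shape[0], fill_value=0)[0]

idx = padded_cols(z, 2)
print(idx)                    # indices 1, 2, 0: the padding index 0 is a real column
print(X[:, idx].sum(axis=1))  # 6 and 15 instead of the intended 5 and 11
```

The padded version compiles, but the extra index 0 silently pulls column 0 into cluster 2's sum, which is exactly the slicing problem described.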

Rewriting the code with boolean indexing, X[:, z==i], instead raises a TracerBoolConversionError.

Is there some way in numpyro to slice X based on z (which is sampled and therefore a traced array) that allows for inference? Thanks!

Looking at your code, it seems to me that z has shape (3, 3) (i.e. (X.shape[1], probs.shape[-1])), in which case jnp.where(z==i)[0] does not make sense. I would recommend writing the JAX code first, checking that it compiles (for example under jax.jit), and only then building the probabilistic model around it. If your goal is to sum X per cluster, you can simply use (X * (z == i)).sum(1), or something similar.
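A sketch of the masking suggestion, again with a hypothetical fixed z and jax.jit standing in for inference-time tracing. The second function is a variant not taken from the reply: it builds a one-hot assignment matrix so that all clusters are handled by a single matrix product.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
z = jnp.array([0, 2, 2])  # hypothetical fixed assignment
K = 3

@jax.jit
def f_masked(X, z):
    # (z == i) has shape (3,) and broadcasts across the rows of X, zeroing out-of-cluster columns
    f_is = jnp.stack([(X * (z == i)).sum(axis=1) ** 2 for i in range(K)], axis=1)
    return f_is.sum(axis=1)

@jax.jit
def f_onehot(X, z):
    # Column k of onehot marks the dimensions assigned to cluster k
    onehot = (z[:, None] == jnp.arange(K)).astype(X.dtype)
    return ((X @ onehot) ** 2).sum(axis=1)

print(f_masked(X, z))  # values 26 and 137
print(f_onehot(X, z))  # values 26 and 137
```

Both versions keep every intermediate array at a fixed shape regardless of how the columns are assigned, which is what allows them to compile.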

Thank you very much!

Running the code outside the probabilistic model shows that z actually has shape (3,), as I intended: one cluster assignment for each of the X.shape[1] data dimensions. f_is has shape (2, 3) for the 2 data points and 3 available clusters, and f has shape (2,), one output per data point.

My goal is eventually to define flexible functions, different from sum(), for each cluster, so masking out-of-cluster dimensions with zeros would no longer work. Is there a general solution?
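A sketch of one general pattern, under the assumption that each within-cluster function can be written to take the full data plus a boolean mask instead of a sliced array. Many jax.numpy reductions accept a where= argument, and a neutral fill value handles empty clusters. The functions sum_sq, masked_max and masked_mean are hypothetical stand-ins for the flexible per-cluster functions.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1., 2., 3.], [4., 5., 6.]])
z = jnp.array([0, 1, 1])  # hypothetical fixed assignment; cluster 2 is empty
K = 3

# Each function receives the full data and a boolean mask of the same shape,
# so all arrays keep static shapes under tracing.
def sum_sq(x, mask):
    return jnp.sum(x, axis=1, where=mask) ** 2

def masked_max(x, mask):
    # max has no identity element, so an initial value is required with where=
    m = jnp.max(x, axis=1, where=mask, initial=-jnp.inf)
    return jnp.where(jnp.isfinite(m), m, 0.0)  # empty cluster -> 0

def masked_mean(x, mask):
    n = mask.sum(axis=1)
    s = jnp.sum(x, axis=1, where=mask)
    return jnp.where(n > 0, s / jnp.maximum(n, 1), 0.0)  # empty cluster -> 0

cluster_fns = [sum_sq, masked_max, masked_mean]  # one function per cluster

@jax.jit
def f_general(X, z):
    f_is = jnp.stack(
        [fn(X, jnp.broadcast_to(z == i, X.shape)) for i, fn in enumerate(cluster_fns)],
        axis=1,
    )
    return f_is.sum(axis=1)

print(f_general(X, z))  # values 4.0 and 22.0

# For plain reductions, segment operations handle all clusters at once with static shapes
per_cluster_sums = jax.ops.segment_sum(X.T, z, num_segments=K).T
print(per_cluster_sums)  # rows [1, 5, 0] and [4, 11, 0]
```

This covers any within-cluster function that can use the mask (or a neutral fill value such as 0 for sums, 1 for products, -inf for max) to ignore out-of-cluster dimensions; jax.ops also provides segment_max, segment_min and segment_prod. A function that genuinely needs a variable-length input would not fit this pattern directly.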