I hope to stochastically cluster the dimensions of some data X, and then perform operations within each cluster of dimensions before adding the clusters' outputs together.

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    X = jnp.array([[1, 2, 3], [4, 5, 6]])  # data matrix; rows are data points, columns are dimensions
    probs = jnp.array([0.1, 0.5, 0.4])  # probability of a dimension being in each cluster

    ## Parts of the model definition omitted here

    K = len(probs)  # number of clusters available
    with numpyro.plate('cluster_assignments', X.shape[1]):  # one assignment per dimension
        z = numpyro.sample('z', dist.Categorical(probs))

    f_is = jnp.zeros((X.shape[0], K))
    for i in range(K):
        # an example within-cluster operation: sum the cluster's columns, then square
        f_is = f_is.at[:, i].set(jnp.sum(X[:, jnp.where(z == i)[0]], axis=1) ** 2)
    f = jnp.sum(f_is, axis=1)  # sum the outputs across clusters

    ### More parts of the model definition omitted here

The problem with this code during inference is that the jnp.where() call produces a dynamically sized array. JAX disallows such arrays and raises a ConcretizationTypeError. Also, I cannot use the three-argument form of jnp.where() (with size and fill_value), because the resulting trailing filler zeros (or any other specified fill value) would change the slicing operation on X.
In contrast, rewriting the code using boolean slicing, X[:, z==i], leads to a different error, since JAX requires boolean index masks to be concrete.
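
The three behaviours described above can be reproduced outside numpyro with a minimal sketch like the following. It assumes that tracing under jax.jit is a fair stand-in for what happens to z during inference, and it fixes i = 1 for brevity. The function names and the example assignment z are hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
z = jnp.array([0, 1, 1])  # hypothetical assignment: dims 1 and 2 are in cluster 1


@jax.jit
def slice_with_where(X, z):
    # Output length of jnp.where depends on the values in z, unknown while tracing
    idx = jnp.where(z == 1)[0]
    return jnp.sum(X[:, idx], axis=1)


@jax.jit
def slice_with_bool(X, z):
    # Boolean mask indexing also needs a concrete (non-traced) mask
    return jnp.sum(X[:, z == 1], axis=1)


@jax.jit
def slice_with_padding(X, z):
    # Fixed output length, but unused slots are filled with index 0,
    # so column 0 is silently pulled into the slice
    idx = jnp.where(z == 1, size=X.shape[1], fill_value=0)[0]
    return jnp.sum(X[:, idx], axis=1)


def error_name(fn, *args):
    """Return the name of the exception raised by fn(*args), or None."""
    try:
        fn(*args)
    except Exception as err:
        return type(err).__name__
    return None


# Reference value, computed eagerly where z is concrete
expected = jnp.sum(X[:, jnp.where(z == 1)[0]], axis=1)

print("where:   ", error_name(slice_with_where, X, z))
print("boolean: ", error_name(slice_with_bool, X, z))
print("padded:  ", slice_with_padding(X, z), "vs expected", expected)
```

The padded variant traces without error but returns different sums from the eager reference, because the filler index 0 adds column 0 to the cluster.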
Is there some way in numpyro to slice X based on z (which is sampled and therefore a traced array) that allows for inference? Thanks!
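
For the specific example operation in the code (sum the cluster's columns, then square), a mask may give the same values without slicing at all, because columns outside the cluster contribute zero to the sum. The following is only a sketch under that assumption: cluster_outputs and the example z are hypothetical names, it does not carry over to within-cluster operations where zeros are not neutral, and it says nothing about how the discrete site z is handled by the inference algorithm.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1, 2, 3], [4, 5, 6]])
K = 3  # number of clusters; a Python int, so it stays static under tracing


@jax.jit
def cluster_outputs(X, z):
    # one_hot[d, k] is 1 when dimension d is assigned to cluster k; shape (D, K)
    one_hot = jax.nn.one_hot(z, K, dtype=X.dtype)
    # Column k of X @ one_hot is the sum of X over the dimensions in cluster k
    f_is = (X @ one_hot) ** 2  # shape (N, K)
    return jnp.sum(f_is, axis=1)  # sum the outputs across clusters


z = jnp.array([0, 1, 1])  # hypothetical assignment used only for this check
print(cluster_outputs(X, z))
```

Every array here has a shape that depends only on X.shape and K, never on the values in z, which is what lets the function trace.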