I want to stochastically cluster the dimensions of some data `X`, and then perform operations within each cluster of dimensions before adding the clusters' outputs together.
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

X = jnp.array([[1, 2, 3], [4, 5, 6]])  # data matrix; rows are data points, columns are dimensions
probs = jnp.array([0.1, 0.5, 0.4])  # probability of a dimension belonging to each cluster

## Parts of the model definition omitted here

K = len(probs)  # number of available clusters
with numpyro.plate('cluster_assignments', X.shape[1]):
    z = numpyro.sample('z', dist.Categorical(probs))

f_is = jnp.zeros((X.shape[0], K))
for i in range(K):
    # an example within-cluster operation: square the sum of the columns assigned to cluster i
    f_is = f_is.at[:, i].set(jnp.sum(X[:, jnp.where(z == i)[0]], axis=1) ** 2)
f = jnp.sum(f_is, axis=1)  # sum the outputs across clusters

### More parts of the model definition omitted here
```
The problem during inference is that `jnp.where(z == i)[0]` produces a dynamically sized array. JAX disallows such arrays and raises a `ConcretizationTypeError`. I also cannot use the three-argument form of `jnp.where()`, because the filler entries (zeros or any other specified filler value) would then participate in the within-cluster operation on `X` and change its result.
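For concreteness, the fixed-shape masked variant I mean is sketched below in plain JAX, with a hand-picked `z` standing in for the sampled assignments (shapes stay static, but the filler values flow through the operation):

```python
import jax.numpy as jnp

X = jnp.array([[1., 2., 3.], [4., 5., 6.]])
z = jnp.array([1, 0, 1])  # example cluster assignments, one per column
K = 2

f_is = jnp.zeros((X.shape[0], K))
for i in range(K):
    # mask out columns not in cluster i instead of slicing them away;
    # the array keeps its full shape, and the zeros enter the sum
    masked = jnp.where(z == i, X, 0.0)
    f_is = f_is.at[:, i].set(jnp.sum(masked, axis=1) ** 2)
f = jnp.sum(f_is, axis=1)  # → [20., 125.] for this toy example
```

For a plain sum the zero filler happens to cancel out, but for within-cluster operations that are sensitive to extra entries (e.g. means or products) the filler changes the result.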
Rewriting the code with boolean slicing, `X[:, z == i]`, instead leads to a `TracerBoolConversionError`.
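Stripped down to plain JAX, the boolean-slicing attempt looks roughly like this; it works eagerly but breaks as soon as the index array is traced, because the output shape depends on the data:

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1., 2., 3.], [4., 5., 6.]])

def cluster_sum(z):
    # X[:, z == 0] has a data-dependent shape, so it only works
    # with concrete (non-traced) z
    return jnp.sum(X[:, z == 0], axis=1)

z = jnp.array([0, 1, 0])
print(cluster_sum(z))        # eager call is fine → [4., 10.]
# jax.jit(cluster_sum)(z)    # fails under tracing with a tracer/concretization error
```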
Is there some way in numpyro to slice `X` based on `z` (which is sampled and therefore a traced array) that allows for inference? Thanks!