I am trying to build a variant of a Gaussian mixture model in which each discrete latent variable generates several continuous observations. I was able to do this with nested plates, as follows:
```python
# Imports assumed (not shown in the original post)
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def scalar_mixture(
    data=None,
    K=2,
    alpha=100,
    data_scale=1,
    cluster_scale=10,
    N_groups=10,
    N_samples_per_group=5
):
    # This is needed because of JAX's random number generation
    with handlers.seed(rng_seed=0):
        # Sample global latent variables
        theta = numpyro.sample('theta', dist.Dirichlet(alpha * np.ones(K,)))
        with numpyro.plate('components', K):
            mus = numpyro.sample('mus', dist.Normal(0, cluster_scale))
        # Sample data points
        with numpyro.plate('groups', N_groups):
            z_a = numpyro.sample('z_a', dist.Categorical(theta))
            with numpyro.plate('samples', N_samples_per_group):
                obs = numpyro.sample('obs', dist.Normal(mus[z_a], data_scale), obs=data)
    return theta, mus, z_a, obs
```
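To check the intended shapes outside NumPyro, here is a plain-NumPy sketch of the same generative process. The `simulate_scalar_mixture` helper and its `seed` argument are hypothetical (not part of NumPyro or the original post); it simply mirrors the prior so that `theta` and `mus` are `(K,)`, `z_a` is `(N_groups,)`, and `obs` is `(N_samples_per_group, N_groups)`, with the inner `samples` plate sitting to the left of the outer `groups` plate.

```python
import numpy as np


def simulate_scalar_mixture(K=2, alpha=100, data_scale=1, cluster_scale=10,
                            N_groups=10, N_samples_per_group=5, seed=0):
    """Hypothetical NumPy re-implementation of the model's prior, for shape checking only."""
    rng = np.random.default_rng(seed)
    # Global latents: mixture weights and one mean per component.
    theta = rng.dirichlet(alpha * np.ones(K))                 # shape (K,)
    mus = rng.normal(0.0, cluster_scale, size=K)              # shape (K,)
    # One discrete assignment per group ('groups' plate -> rightmost axis).
    z_a = rng.choice(K, size=N_groups, p=theta)               # shape (N_groups,)
    # Several observations per group ('samples' plate -> next axis to the left),
    # so mus[z_a] of shape (N_groups,) broadcasts across the new leading axis.
    obs = rng.normal(mus[z_a], data_scale,
                     size=(N_samples_per_group, N_groups))    # shape (N_samples_per_group, N_groups)
    return theta, mus, z_a, obs
```

For example, `simulate_scalar_mixture(K=2, N_groups=100, N_samples_per_group=5)` returns `obs` with shape `(5, 100)`.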
I can sample from this model using:
```python
N_per_group = 5
N_groups = 100
K = 2
cluster_scale = 50
theta_true, mu_true, z_true, x = scalar_mixture(K=K,
                                                cluster_scale=cluster_scale,
                                                N_samples_per_group=N_per_group,
                                                N_groups=N_groups)
```
As expected, the observations `x` have shape `(N_per_group, N_groups)`, which is `(5, 100)` in this case.
However, when I run inference on this synthetic data, I get a shape mismatch:
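```python
# Imports assumed (not shown in the original post)
from jax import random
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(scalar_mixture)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key,
         data=x,
         K=K,
         cluster_scale=cluster_scale,
         N_samples_per_group=N_per_group,
         N_groups=N_groups,
         )
```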
This produces a long chain of JAX errors, the last of which appears to be a shape error:
```
ValueError: Incompatible shapes for broadcasting: ((1, 1, 100), (2, 1, 5))
```
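The error follows NumPy-style broadcasting, which JAX also uses: shapes are aligned from the right, and each pair of dimensions must be equal or 1. A small NumPy check (the `can_broadcast` helper is hypothetical; `np.broadcast_shapes` needs NumPy 1.20 or later) shows that the two shapes in the traceback clash in their last dimension, where a size-100 axis (`N_groups`) is lined up against a size-5 axis (`N_samples_per_group`). The `(2, 1, 1)` shape in the second check is an assumption about how `K = 2` values held in a dimension to the left of both plates would look; it is not taken from the traceback.

```python
import numpy as np


def can_broadcast(*shapes):
    """Return True if NumPy-style broadcasting (also used by JAX) accepts these shapes."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


# Shapes from the traceback: aligned from the right, the last dims are 100 vs 5,
# and neither is 1, so broadcasting fails.
print(can_broadcast((1, 1, 100), (2, 1, 5)))  # False

# Hypothetical comparison: a (2, 1, 1) tensor (K = 2 values in a dim left of both
# plates) broadcasts cleanly against data laid out as (N_per_group, N_groups).
print(np.broadcast_shapes((2, 1, 1), (5, 100)))  # (2, 5, 100)
```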
Does anyone know what is going on here? I assumed I could pass data drawn from the model back into the model for inference, so I must have set something up incorrectly.