I am trying to create a variant of a Gaussian mixture model where each discrete latent variable gets multiple continuous observations. I was able to do so using nested plates, as follows:
```python
def scalar_mixture(
    data=None,
    K=2,
    alpha=100,
    data_scale=1,
    cluster_scale=10,
    N_groups=10,
    N_samples_per_group=5
):
    # This is needed because of JAX's random number gen
    with handlers.seed(rng_seed=0):
        # Sample global latent variables
        theta = numpyro.sample('theta', dist.Dirichlet(alpha * np.ones(K,)))
        with numpyro.plate('components', K):
            mus = numpyro.sample('mus', dist.Normal(0, cluster_scale))

        # Sample data points
        with numpyro.plate('groups', N_groups):
            z_a = numpyro.sample('z_a', dist.Categorical(theta))
            with numpyro.plate('samples', N_samples_per_group):
                obs = numpyro.sample('obs', dist.Normal(mus[z_a], data_scale), obs=data)
    return theta, mus, z_a, obs
```
I can sample from this model using:
```python
N_per_group = 5
N_groups = 100
K = 2
cluster_scale = 50

theta_true, mu_true, z_true, x = scalar_mixture(
    K=K,
    cluster_scale=cluster_scale,
    N_samples_per_group=N_per_group,
    N_groups=N_groups,
)
```
As I would expect, my observations x have shape (N_per_group, N_groups), which is (5, 100) in this case.
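To make the shape bookkeeping concrete, here is a rough NumPy-only sketch of the same generative process. It is not NumPyro code, and it only assumes the layout reported above: groups along the last axis and the per-group samples along the axis before it. The names rng and group_means are just for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N_groups, N_per_group = 2, 100, 5
cluster_scale, data_scale = 50.0, 1.0

# One mean per mixture component: shape (K,)
mus = rng.normal(0.0, cluster_scale, size=K)

# One discrete assignment per group: shape (N_groups,)
z = rng.integers(0, K, size=N_groups)

# Look up each group's mean: shape (N_groups,)
group_means = mus[z]

# Several observations per group; the group axis is last, so the means
# broadcast across the leading sample axis: shape (N_per_group, N_groups)
x = rng.normal(group_means, data_scale, size=(N_per_group, N_groups))

print(x.shape)  # (5, 100)
```

Every column of x shares a single component assignment, which is the "multiple observations per discrete latent" structure I am after.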
However, when I go to run inference on my synthetic data I get a shape mismatch:
```python
nuts_kernel = NUTS(scalar_mixture)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(1)
mcmc.run(
    rng_key,
    data=x,
    K=K,
    cluster_scale=cluster_scale,
    N_samples_per_group=N_per_group,
    N_groups=N_groups,
)
```
This gives a long chain of JAX errors, where the final one seems to be a shape issue:
```
ValueError: Incompatible shapes for broadcasting: ((1, 1, 100), (2, 1, 5))
```
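As a sanity check on the message itself (not a diagnosis of where these shapes come from inside the model), the two shapes can be compared under NumPy's broadcasting rules, which JAX follows for this kind of elementwise operation. The helper can_broadcast is a hypothetical name of mine, and np.broadcast_shapes needs NumPy 1.20 or newer.

```python
import numpy as np

def can_broadcast(*shapes):
    """Return True if the given shapes are mutually broadcastable."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False

# The two shapes from the error: the last axes are 100 and 5, and neither is 1
print(can_broadcast((1, 1, 100), (2, 1, 5)))  # False

# For comparison, the data shape (5, 100) does line up with (1, 1, 100)
print(can_broadcast((1, 1, 100), (5, 100)))   # True
```

So the clash is on the trailing axis: one operand has the 100 groups last, while the other has the 5 samples per group last.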
Does anyone know what is going on here? I thought that I should surely be able to pass data drawn from the model back to the model, but I must have set something up incorrectly.