Data shape mismatch with nested plates

I am trying to create a variant of a Gaussian mixture model where each discrete latent variable gets multiple continuous observations. I was able to do so using nested plates, as follows:

```python
import jax.numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def scalar_mixture(
    data=None,
    K=2,
    alpha=100,
    data_scale=1,
    cluster_scale=10,
    N_groups=10,
    N_samples_per_group=5
):
    # Seed the model so it can be called directly (JAX needs an explicit PRNG key)
    with handlers.seed(rng_seed=0):

        # Sample global latent variables
        theta = numpyro.sample('theta', dist.Dirichlet(alpha * np.ones(K,)))
        with numpyro.plate('components', K):
            mus = numpyro.sample('mus', dist.Normal(0, cluster_scale))

        # Sample data points
        with numpyro.plate('groups', N_groups):
            z_a = numpyro.sample('z_a', dist.Categorical(theta))
            with numpyro.plate('samples', N_samples_per_group):
                obs = numpyro.sample('obs', dist.Normal(mus[z_a], data_scale), obs=data)
    return theta, mus, z_a, obs
```

I can sample from this model using:

```python
N_per_group=5
N_groups=100
K=2
cluster_scale=50

theta_true, mu_true, z_true, x = scalar_mixture(K=K,
                                             cluster_scale=cluster_scale,
                                             N_samples_per_group=N_per_group,
                                             N_groups=N_groups)
```

As expected, my observations `x` have shape `(N_per_group, N_groups)`, which is `(5, 100)` in this case.

However, when I run inference on this synthetic data, I get a shape mismatch:

```python
from jax import random
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(scalar_mixture)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key,
         data=x,
         K=K,
         cluster_scale=cluster_scale,
         N_samples_per_group=N_per_group,
         N_groups=N_groups,
         )
```

This produces a long chain of JAX errors, the last of which appears to be a shape issue:

```
ValueError: Incompatible shapes for broadcasting: ((1, 1, 100), (2, 1, 5))
```
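The failing broadcast can be reproduced with plain NumPy shape arithmetic. This is a sketch under one assumption: that during NUTS the K values of `z_a` are enumerated in a new dimension to the left of the two plate dimensions, so `mus[z_a]` has shape `(K, 1, 1)`. With that layout the observed `x` of shape `(5, 100)` broadcasts cleanly, whereas the two shapes in the error message put the size-100 and size-5 axes in the same position (dim -1), which broadcasting rejects.

```python
import numpy as np

K, N_per_group, N_groups = 2, 5, 100

# Intended layout: groups at dim=-1, samples at dim=-2, so x has shape (5, 100).
data_shape = (N_per_group, N_groups)

# Assumption: enumeration places the K values of z_a in dim -3,
# left of both plate dims, so mus[z_a] has shape (K, 1, 1).
enum_loc_shape = (K, 1, 1)

# Consistent layout: broadcasts to (K, N_per_group, N_groups).
ok_shape = np.broadcast_shapes(enum_loc_shape, data_shape)
print(ok_shape)  # (2, 5, 100)

# Shapes from the error message: 100 and 5 both sit in dim -1, so this fails.
try:
    np.broadcast_shapes((1, 1, 100), (2, 1, 5))
    failed = False
except ValueError:
    failed = True
print(failed)  # True
```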

Does anyone know what is going on here? I assumed I could pass data drawn from the model back into the model, so I must have set something up incorrectly.

@bantin This looks like a bug in our enumeration code. In the meantime, you can pass the `dim` keyword to the `groups` and `samples` plates (`dim=-1` and `dim=-2`, respectively). (Edit: this PR fixes the issue. Thank you for the report!)

Thanks @fehiepsi! Adding the `dim` keyword argument fixed things for me.