Understanding nested plates in numpyro

I’m trying to understand the following code snippet, which includes nested plates:

  import numpyro
  import numpyro.distributions as dist

  n_participants = 10
  n_levels = 6

  def model():
      with numpyro.plate("n_levels", n_levels, dim=-2):
          a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(3.0))

          with numpyro.plate("n_participants", n_participants, dim=-1):
              a = numpyro.sample("a", dist.Normal(a_level_mean, 1.0))
      return a_level_mean, a

With this snippet, a_level_mean has shape (6, 1) and a has shape (6, 10).

However, it is not clear to me whether the code above encodes the hierarchy

a[L, P] ~ N(a_level_mean[L], 1.0) for all L = 1, 2, …, 6 and P = 1, 2, …, 10

which is what I’m looking for. The reason I’m unsure is that inside the inner plate a_level_mean is not subscripted and has shape (6, 1), so it seems ambiguous which element each a[L, P] uses.
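Written as a joint density, with \mu_L as shorthand for a_level_mean[L] and assuming the HalfNormal(3.0) prior from the snippet on each level mean, that hierarchy is:

```latex
p(\mu, a) = \prod_{L=1}^{6} \left[ \mathrm{HalfNormal}(\mu_L \mid 3)
            \prod_{P=1}^{10} \mathcal{N}(a_{L,P} \mid \mu_L, 1) \right]
```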

Is there a way to verify this? If the code above does not induce that hierarchy, is there a way to achieve it?

It seems to me that the code does what you need. a_level_mean.shape == (6, 1) is expected. Why do you think it is ambiguous?

Both a_level_mean.shape == (6, 1) and a.shape == (6, 10) are expected, but I’m not sure whether the code snippet induces the following dependency:

a[L, P] ~ N(a_level_mean[L], 1.0) 

In particular, I want each a[L, P] to be drawn from a Gaussian with mean a_level_mean[L].

Right now I am not sure, because inside the inner plate I can’t use subscripts.

What I really want to do is something like this:

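# Plain Python lists, since JAX arrays do not support item assignment
a_level_mean = [None] * n_levels
a = [[None] * n_participants for _ in range(n_levels)]
for i in range(n_levels):
    a_level_mean[i] = numpyro.sample(f"a_level_mean_{i}", dist.HalfNormal(3.0))

    for j in range(n_participants):
        a[i][j] = numpyro.sample(f"a_{i}_{j}", dist.Normal(a_level_mean[i], 1.0))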

But I want to write this using plates.

Yes, it induces that dependency. One way to think about it is to broadcast a_level_mean to a.shape. If a ~ Normal(0, 1) with a.shape == (3, 5), then each a[i, j] follows Normal(0, 1). If a ~ Normal(a_mean, 1) with a.shape == (3, 5) and a_mean.shape == (3, 1), then each a[i, j] follows Normal(a_mean[i, 0], 1). Broadcasting a_mean to a.shape repeats its single column five times:

  a_mean_broadcasted = jnp.concatenate([a_mean, a_mean, a_mean, a_mean, a_mean], axis=-1)  # shape (3, 5)

so a[i, j] follows Normal(a_mean_broadcasted[i, j], jnp.ones((3, 5))[i, j]) == Normal(a_mean[i, 0], 1). The same applies to your model: a_level_mean of shape (6, 1) is broadcast across the participants dimension, so a[L, P] uses a_level_mean[L, 0].