 # Understanding nested plates in numpyro

I’m trying to understand the following code snippet, which includes nested plates:

```python
import numpyro
import numpyro.distributions as dist

n_participants = 10
n_levels = 6

def model():
    with numpyro.plate("n_levels", n_levels, dim=-2):
        a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(3.0))

        with numpyro.plate("n_participants", n_participants, dim=-1):
            a = numpyro.sample("a", dist.Normal(a_level_mean, 1.0))
```

The above gives `a_level_mean` a shape of `(6, 1)` and `a` a shape of `(6, 10)`.

However, it is not clear to me whether the code above encodes the hierarchy

`a[L, P] ~ N(a_level_mean[L], 1.0)` for all L = 1, 2, …, 6 and P = 1, 2, …, 10,

which is what I’m looking for. My doubt is that inside the inner plate `a_level_mean` is not subscripted, so with shape `(6, 1)` it is not obvious which entry each `a[L, P]` uses.

Is there a way to verify this? If the above does not induce that hierarchy, is there a way to achieve it?

It seems to me that the code does what you need. `a_level_mean.shape == (6, 1)` is expected. Why do you think it is ambiguous?

```python
a_level_mean.shape == (6, 1), a.shape == (6, 10)
```

Both are expected, but I’m not sure whether the code snippet induces the following dependency:

```
a[L, P] ~ N(a_level_mean[L], 1.0)
```

In particular, I want `a[L, P]` to be drawn from a Gaussian with mean `a_level_mean[L]`.

Right now I’m not sure, because inside the inner plate I can’t use subscripts.

What I really want to do is something like this:

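```python
import numpyro
import numpyro.distributions as dist

# Inside the model: one sample site per level and per (level, participant) pair
a_level_mean = [None] * n_levels
a = [[None] * n_participants for _ in range(n_levels)]
for i in range(n_levels):
    a_level_mean[i] = numpyro.sample(f"a_level_mean{i}", dist.HalfNormal(3.0))
    for j in range(n_participants):
        a[i][j] = numpyro.sample(f"a{i}{j}", dist.Normal(a_level_mean[i], 1.0))
```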

But I want to write this using plates.

Yes, it induces that dependency. One way to think about it is to broadcast `a_level_mean` to `a.shape`. If you have `a ~ Normal(0, 1)` and `a.shape == (3, 5)`, then each `a[i, j]` follows `Normal(0, 1)`. If you have `a ~ Normal(a_mean, 1)` with `a.shape == (3, 5)` and `a_mean.shape == (3, 1)`, then each `a[i, j]` follows `Normal(a_mean[i, 0], 1)`. Broadcasting `a_mean` to `a.shape` gives
`a_mean_broadcasted = jnp.concatenate([a_mean, a_mean, a_mean, a_mean, a_mean], axis=-1)`, so `a[i, j]` follows `Normal(a_mean_broadcasted[i, j], jnp.ones((3, 5))[i, j]) == Normal(a_mean[i, 0], 1)`.