I’m trying to understand the following code snippet, which includes nested plates:
```python
import numpyro
import numpyro.distributions as dist

n_participants = 10
n_levels = 6

def model():
    with numpyro.plate("n_levels", n_levels, dim=-2):
        a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(3.0))
        with numpyro.plate("n_participants", n_participants, dim=-1):
            a = numpyro.sample("a", dist.Normal(a_level_mean, 1.0))
```
With the above, a_level_mean has shape (6, 1) and a has shape (6, 10).
However, it is not clear to me whether this code encodes the hierarchy

a[L, P] ~ N(a_level_mean[L], 1.0) for all L = 1, 2, …, 6 and P = 1, 2, …, 10,

which is what I’m looking for. My doubt comes from the inner plate: a_level_mean is used there without a subscript, and since it has shape (6, 1), it is not obvious which of its entries each a[L, P] is drawn around.
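The missing subscript is resolved by broadcasting: NumPyro distribution parameters follow NumPy broadcasting rules, so a (6, 1) `loc` used inside the (6, 10) plate context has its single value in row L repeated across all ten participant columns. The sketch below shows this in plain `jax.numpy`; the values in `level_mean` are hypothetical, chosen to be distinct per level so the check means something.

```python
import jax.numpy as jnp

n_levels, n_participants = 6, 10

# Hypothetical per-level means 0, 10, ..., 50, shaped (6, 1) like a_level_mean
level_mean = (10.0 * jnp.arange(n_levels)).reshape(n_levels, 1)

# Broadcast the (6, 1) column to the (6, 10) shape implied by the two plates
loc = jnp.broadcast_to(level_mean, (n_levels, n_participants))

# Every cell [l, p] picks up the mean of its own level l, whatever p is
for l in range(n_levels):
    for p in range(n_participants):
        assert loc[l, p] == level_mean[l, 0]
```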
Is there a way to verify this? If the above does not induce that hierarchy, how can I achieve it?