# Help with using nested plates with multivariate Gaussian hierarchical model

My data set looks like this:

``````
participant, level, intensity, output

participant_1, level_1, intensity_1_11, [output_1_11]
participant_1, level_1, intensity_2_11, [output_2_11]
...
participant_1, level_1, intensity_100_11, [output_100_11]

participant_1, level_5, intensity_1_15, [output_1_15]
participant_1, level_5, intensity_2_15, [output_2_15]
...
participant_1, level_5, intensity_100_15, [output_100_15]

...
participant_2, level_1, intensity_1_21, [output_1_21]
participant_2, level_1, intensity_2_21, [output_2_21]
...
participant_2, level_1, intensity_100_21, [output_100_21]

participant_2, level_2, intensity_1_22, [output_1_22]
participant_2, level_2, intensity_2_22, [output_2_22]
...
participant_2, level_2, intensity_100_22, [output_100_22]
participant_2, level_3, intensity_1_23, [output_1_23]
participant_2, level_3, intensity_2_23, [output_2_23]
...
participant_2, level_3, intensity_100_23, [output_100_23]
...

``````

where each element of the output column (such as [output_1_11]) is a vector with 4 elements.
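Because the model below indexes `a[level, participant]`, the `participant` and `level` columns have to be integer codes starting at 0, and the `output` vectors have to be stacked into one `(N, 4)` array. A minimal sketch of that preparation, assuming the labels are read as strings; the rows and output values below are made up for illustration only.

```python
import numpy as np

# Hypothetical rows in the format shown above (string labels, 4-element outputs)
participant_raw = np.array(["participant_1", "participant_1", "participant_2", "participant_2"])
level_raw = np.array(["level_1", "level_5", "level_1", "level_2"])
output_raw = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
              [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]

# return_inverse=True gives, for every row, the position of its label in the
# sorted array of unique labels (note: string sorting puts "level_10" before "level_2")
participant_names, participant = np.unique(participant_raw, return_inverse=True)
level_names, level = np.unique(level_raw, return_inverse=True)

# Stack the per-row output vectors into a single array of shape (N, 4)
output = np.asarray(output_raw, dtype=float)

print(participant_names, participant)
print(level_names, level)
print(output.shape)
```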

I’m trying to build a very basic hierarchical model:

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist


def model(intensity, participant, level, output=None):
    # Number of distinct participants and levels (both encoded as integer indices)
    n_participants = np.unique(participant).shape[0]
    n_levels = np.unique(level).shape[0]

    with numpyro.plate("n_levels", n_levels, dim=-2):
        # Level-specific means of the 4-dimensional parameters a and b
        a_level_mean = numpyro.sample("a_level_mean", dist.HalfCauchy(jnp.ones(4)))
        b_level_mean = numpyro.sample("b_level_mean", dist.HalfCauchy(jnp.ones(4)))

        with numpyro.plate("n_participants", n_participants, dim=-1):
            # Participant-specific parameters, centred on the level means
            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))

    # Shared diagonal observation covariance
    sigma = numpyro.sample("sigma", dist.HalfCauchy(3 * jnp.ones(4)))
    cov = jnp.diag(sigma)

    # Each row uses the a and b of its (level, participant) pair
    mean = jax.nn.relu(jnp.multiply(b[level, participant], jnp.tile(intensity, (4, 1)).T - a[level, participant]))

    with numpyro.plate("data", len(intensity)):
        return numpyro.sample("obs", dist.MultivariateNormal(mean, covariance_matrix=cov), obs=output)


numpyro.render_model(model, model_args=(intensity, participant, level, output), filename="model.png")
```

This throws the following error:

```
ValueError: Incompatible shapes for broadcasting: ((18,), (6,))
```

Here, `n_participants = 18` and `n_levels = 6`.
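The message comes from broadcasting a distribution's batch shape against the shape declared by the plates. Assuming the two plates are nested (dims -2 and -1, as in the model above), the check can be mimicked with plain NumPy shape arithmetic: `HalfCauchy(jnp.ones(4))` under the level plate yields an `a_level_mean` of shape `(6, 4)`, so `MultivariateNormal` treats the last axis as the event and has batch shape `(6,)`, whose trailing `6` is aligned against `n_participants = 18`. This only illustrates the broadcasting rule; it is not NumPyro's internal code.

```python
import numpy as np

n_levels, n_participants = 6, 18
plate_shape = (n_levels, n_participants)  # from plate dims -2 and -1


def broadcast_or_none(batch_shape):
    """Broadcast batch_shape against plate_shape; return None if incompatible."""
    try:
        return np.broadcast_shapes(batch_shape, plate_shape)
    except ValueError:
        return None


# loc of shape (6, 4): last axis is the event, batch shape (6,) -> trailing 6 vs 18 fails
print(broadcast_or_none((n_levels,)))
# loc of shape (6, 1, 4): batch shape (6, 1) broadcasts to the plate shape
print(broadcast_or_none((n_levels, 1)))
```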

Also, the `output` passed to `numpyro.render_model` is two-dimensional with shape `(len(intensity), 4)`, where `len(intensity)` is the total number of observations in my data.

If, in the code snippet above, I replace

```python
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))
```

with

```python
a = numpyro.sample("a", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
```

the model renders as follows:

![Screen Shot 2022-06-16 at 10.54.47 AM|690x293](upload://tmh7iM25kj5kkTfyDCBVc777JuL.png)

Can you check the shape of the `a_level_mean` variable? Note that `dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4)))` should have a batch shape that is broadcastable to `(n_levels, n_participants)`, because those dimensions are declared as conditionally independent by the `plate` contexts.