Help using nested plates with a multivariate Gaussian hierarchical model

My data set looks like this:

```
participant, level, intensity, output

participant_1, level_1, intensity_1_11, [output_1_11]
participant_1, level_1, intensity_2_11, [output_2_11]
...
participant_1, level_1, intensity_100_11, [output_100_11]

participant_1, level_5, intensity_1_15, [output_1_15]
participant_1, level_5, intensity_2_15, [output_2_15]
...
participant_1, level_5, intensity_100_15, [output_100_15]

...
participant_2, level_1, intensity_1_21, [output_1_21]
participant_2, level_1, intensity_2_21, [output_2_21]
...
participant_2, level_1, intensity_100_21, [output_100_21]

participant_2, level_2, intensity_1_22, [output_1_22]
participant_2, level_2, intensity_2_22, [output_2_22]
...
participant_2, level_2, intensity_100_22, [output_100_22]
participant_2, level_3, intensity_1_23, [output_1_23]
participant_2, level_3, intensity_2_23, [output_2_23]
...
participant_2, level_3, intensity_100_23, [output_100_23]
...
```

where each entry of the `output` column (such as `[output_1_11]`) is a vector with 4 elements.
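Because the model indexes `a[level, participant]`, the `participant` and `level` columns need to be integer codes rather than string labels. A minimal sketch of one way to get them with NumPy, using a few hypothetical rows in the layout above (the label strings, intensities, and zero-filled outputs are made up for illustration):

```python
import numpy as np

# Hypothetical long-format columns: one row per observation,
# with a 4-element output vector per row.
participant_raw = np.array(["participant_1", "participant_1", "participant_2",
                            "participant_2", "participant_2"])
level_raw = np.array(["level_1", "level_5", "level_1", "level_2", "level_3"])
intensity = np.array([0.1, 0.2, 0.1, 0.3, 0.5])
output = np.zeros((len(intensity), 4))  # placeholder values, shape (N, 4)

# return_inverse=True maps each label to an integer code in 0..K-1
# (codes follow the sorted order of the unique labels).
participant_names, participant = np.unique(participant_raw, return_inverse=True)
level_names, level = np.unique(level_raw, return_inverse=True)

print(participant)  # [0 0 1 1 1]
print(level)        # [0 3 0 1 2]  (level_1, level_2, level_3, level_5 -> 0..3)
```

Note that the codes follow lexicographic order of the labels, so a label like `level_10` would sort before `level_2`.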

I’m trying to build a very basic hierarchical model:

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist


def model(intensity, participant, level, output=None):
    # participant and level must be integer codes (0..K-1), since they index a and b below
    n_participants = np.unique(participant).shape[0]
    n_levels = np.unique(level).shape[0]

    # level-specific means for the 4 output channels
    with numpyro.plate("n_levels", n_levels, dim=-2):
        a_level_mean = numpyro.sample("a_level_mean", dist.HalfCauchy(jnp.ones(4)))
        b_level_mean = numpyro.sample("b_level_mean", dist.HalfCauchy(jnp.ones(4)))

        # participant-specific parameters, nested within each level
        with numpyro.plate("n_participants", n_participants, dim=-1):
            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))

    sigma = numpyro.sample("sigma", dist.HalfCauchy(3 * jnp.ones(4)))
    cov = jnp.diag(sigma)

    # relu(b * (intensity - a)), computed separately for each of the 4 output channels
    mean = jax.nn.relu(jnp.multiply(b[level, participant], jnp.tile(intensity, (4, 1)).T - a[level, participant]))

    with numpyro.plate("data", len(intensity)):
        return numpyro.sample("obs", dist.MultivariateNormal(mean, covariance_matrix=cov), obs=output)


numpyro.render_model(model, model_args=(intensity, participant, level, output), filename="model.png")
```
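The `mean` line in the model computes `relu(b * (intensity - a))` separately for each of the 4 output channels. The sketch below traces the shapes involved using plain NumPy as a stand-in for `jax.numpy`; the sizes and values are hypothetical, and `a`, `b` are given the `(n_levels, n_participants, 4)` shape that the nested plates are intended to produce:

```python
import numpy as np

# Hypothetical sizes and random values, used only to show the shapes involved.
rng = np.random.default_rng(0)
N, n_levels, n_participants = 5, 6, 18
intensity = rng.uniform(0.0, 1.0, size=N)
level = rng.integers(0, n_levels, size=N)
participant = rng.integers(0, n_participants, size=N)

# The shape the nested plates are meant to give a and b: (levels, participants, 4).
a = rng.normal(size=(n_levels, n_participants, 4))
b = rng.normal(size=(n_levels, n_participants, 4))

# Indexing with two integer arrays of length N picks one 4-vector per observation.
a_obs = a[level, participant]  # (N, 4)
b_obs = b[level, participant]  # (N, 4)

# np.tile(intensity, (4, 1)).T is an explicit (N, 4) copy of intensity;
# intensity[:, None] broadcasts to the same values without building the copy.
tiled = np.tile(intensity, (4, 1)).T
assert np.array_equal(tiled, np.broadcast_to(intensity[:, None], (N, 4)))

# relu(b * (intensity - a)), elementwise; np.maximum(0, x) matches jax.nn.relu(x).
mean = np.maximum(0.0, b_obs * (intensity[:, None] - a_obs))
print(mean.shape)  # (5, 4)
```

Since broadcasting works the same way in `jax.numpy`, `intensity[:, None]` could replace the `jnp.tile(...)` call in the model without changing the result.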

This throws the following error:

```
ValueError: Incompatible shapes for broadcasting: ((18,), (6,))
```

Here, `n_participants = 18` and `n_levels = 6`.
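As far as I can tell, the error comes from the plate shape check: each `numpyro.plate` requires a site's batch shape to broadcast against the shape declared by the enclosing plates, here `(n_levels, n_participants) = (6, 18)`, aligned from the right. The sketch below mimics that check with `np.broadcast_shapes`; the intermediate shapes in the comments are my reading of how the plates and `MultivariateNormal` treat these arrays, not output copied from NumPyro:

```python
import numpy as np

n_levels, n_participants = 6, 18
plate_shape = (n_levels, n_participants)  # dims -2 and -1 from the two plates

# Inside the "n_levels" plate (dim=-2), HalfCauchy(jnp.ones(4)) has batch shape (4,),
# which the plate expands to (6, 4): the length-4 axis is a *batch* dim here.
a_level_mean_shape = (n_levels, 4)

# MultivariateNormal treats the last axis of loc as its event dim,
# so its batch shape is everything before it: (6,).
mvn_batch_shape = a_level_mean_shape[:-1]

try:
    np.broadcast_shapes(plate_shape, mvn_batch_shape)
except ValueError as err:
    # 18 (participants) and 6 (levels) clash in the last dim
    print("incompatible:", err)

# A level mean of shape (6, 1, 4) would give an MVN batch shape of (6, 1),
# which broadcasts cleanly against the plates.
print(np.broadcast_shapes(plate_shape, (n_levels, 1)))  # (6, 18)

# The variant with loc=jnp.ones(4) has batch shape (), which also broadcasts.
print(np.broadcast_shapes(plate_shape, ()))  # (6, 18)
```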

Also, the `output` array passed to `numpyro.render_model` is two-dimensional, with shape `(len(intensity), 4)`, where `len(intensity)` is the total number of observations in my data.

If, in the code above, I replace

```python
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))
```

with

```python
a = numpyro.sample("a", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
```

the model renders as follows:


![Screen Shot 2022-06-16 at 10.54.47 AM|690x293](upload://tmh7iM25kj5kkTfyDCBVc777JuL.png)


Could someone please help me debug this?

On which line do you get this error message?

Can you check the shape of the `a_level_mean` variable? Note that `dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4)))` should have a batch shape that is broadcastable to `(n_levels, n_participants)`, since those dimensions are declared conditionally independent by the plate contexts.