My data set looks like this:
participant, level, intensity, output
participant_1, level_1, intensity_1_11, [output_1_11]
participant_1, level_1, intensity_2_11, [output_2_11]
...
participant_1, level_1, intensity_100_11, [output_100_11]
participant_1, level_5, intensity_1_15, [output_1_15]
participant_1, level_5, intensity_2_15, [output_2_15]
...
participant_1, level_5, intensity_100_15, [output_100_15]
...
participant_2, level_1, intensity_1_21, [output_1_21]
participant_2, level_1, intensity_2_21, [output_2_21]
...
participant_2, level_1, intensity_100_21, [output_100_21]
participant_2, level_2, intensity_1_22, [output_1_22]
participant_2, level_2, intensity_2_22, [output_2_22]
...
participant_2, level_2, intensity_100_22, [output_100_22]
participant_2, level_3, intensity_1_23, [output_1_23]
participant_2, level_3, intensity_2_23, [output_2_23]
...
participant_2, level_3, intensity_100_23, [output_100_23]
...
where the output column elements (such as [output_1_11]) are vectors that contain 4 elements.
I’m trying to build a very basic hierarchical model
def model(intensity, participant, level, output=None):
n_participants = np.unique(participant).shape[0]
n_levels = np.unique(level).shape[0]
with numpyro.plate("n_levels", n_levels, dim=-2):
a_level_mean = numpyro.sample("a_level_mean", dist.HalfCauchy(jnp.ones(4)))
b_level_mean = numpyro.sample("b_level_mean", dist.HalfCauchy(jnp.ones(4)))
with numpyro.plate("n_participants", n_participants, dim=-1):
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))
sigma = numpyro.sample('sigma', dist.HalfCauchy(3*jnp.ones(4)))
cov = jnp.diag(sigma)
mean = jax.nn.relu( jnp.multiply(b[level, participant], jnp.tile(intensity, (4,1)).T - a[level, participant]))
with numpyro.plate("data", len(intensity)):
return numpyro.sample("obs", dist.MultivariateNormal(mean, covariance_matrix=cov), obs=output)
numpyro.render_model(model, model_args=(intensity, participant, level, output), filename='model.png')
This throws the following error:
ValueError: Incompatible shapes for broadcasting: ((18,), (6,))
Here, n_participants = 18, n_levels = 6
Also, the ‘output’ passed to numpyro.render_model is 2-dimensional with shape (len(intensity), 4) where len(intensity) is the number of total observations I have in my data.
If in the above code snippet, I replace
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))
by
a = numpyro.sample("a", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(jnp.ones(4), jnp.diag(jnp.ones(4))))
``
the model renders as follows
![Screen Shot 2022-06-16 at 10.54.47 AM|690x293](upload://tmh7iM25kj5kkTfyDCBVc777JuL.png)
Could someone please help me debug this?