def model(intensity, participant, level, mep_size_obs=None):
a_level_mean_global_scale = numpyro.sample('a_level_mean_global_scale', dist.HalfNormal(5.0*jnp.ones(4)))
b_level_mean_global_scale = numpyro.sample('b_level_mean_global_scale', dist.HalfNormal(5.0*jnp.ones(4)))
a_level_scale_global_scale = numpyro.sample('a_level_scale_global_scale', dist.HalfNormal(jnp.ones(4)))
b_level_scale_global_scale = numpyro.sample('b_level_scale_global_scale', dist.HalfNormal(jnp.ones(4)))
n_participants = np.unique(participant).shape[0]
n_levels = np.unique(level).shape[0]
with numpyro.plate("n_levels", n_levels, dim=-2):
a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(a_level_mean_global_scale)).reshape(n_levels, 1, 4)
b_level_mean = numpyro.sample("b_level_mean", dist.HalfNormal(b_level_mean_global_scale)).reshape(n_levels, 1, 4)
a_level_scale = numpyro.sample("a_level_scale", dist.HalfNormal(a_level_scale_global_scale))
b_level_scale = numpyro.sample("b_level_scale", dist.HalfNormal(b_level_scale_global_scale))
with numpyro.plate("n_participants", n_participants, dim=-1):
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(a_level_scale)))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(b_level_scale)))
sigma = numpyro.sample('sigma', dist.HalfCauchy(3*jnp.ones(4)))
cov = jnp.diag(sigma)
mean = jax.nn.relu(jnp.multiply(b[level, participant], jnp.tile(intensity, (4,1)).T - a[level, participant]))
with numpyro.plate("data", len(intensity)):
return numpyro.sample("obs", dist.MultivariateNormal(mean, cov), obs=mep_size_obs)
numpyro.render_model(model, model_args=(intensity, participant, level, mep_size))
My model is rendering properly and I can see the plates figure. However, when I do:
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=10000)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key, intensity, participant, level, mep_size)
posterior_samples = mcmc.get_samples()
I get the following error:
ValueError: MultivariateNormal distribution got invalid covariance_matrix parameter.
print(a_level_mean_global_scale.shape)
print(a_level_scale_global_scale.shape)
print(a_level_mean.shape)
print(a_level_scale.shape)
print(a.shape)
print(mean.shape, cov.shape)
gives
(4,)
(4,)
(6, 1, 4)
(6, 4)
(6, 18, 4)
(1497, 4) (4, 4)
If I replace
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(a_level_scale)))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(b_level_scale)))
with
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))
the code works properly.
I figured the error is coming from jnp.diag(a_level_scale) because a_level_scale is (6,4), which is 2D. Looking at jnp.diag documentation, I found that jnp.diag(a_level_scale) results in a (4,) array instead of a (6,4,4) array.
Is there a work around for this?