Numpyro ValueError - MultivariateNormal distribution got invalid covariance_matrix parameter

def model(intensity, participant, level, mep_size_obs=None):

    a_level_mean_global_scale = numpyro.sample('a_level_mean_global_scale', dist.HalfNormal(5.0*jnp.ones(4)))
    b_level_mean_global_scale = numpyro.sample('b_level_mean_global_scale', dist.HalfNormal(5.0*jnp.ones(4)))

    a_level_scale_global_scale = numpyro.sample('a_level_scale_global_scale', dist.HalfNormal(jnp.ones(4)))
    b_level_scale_global_scale = numpyro.sample('b_level_scale_global_scale', dist.HalfNormal(jnp.ones(4)))

    n_participants = np.unique(participant).shape[0]
    n_levels = np.unique(level).shape[0]

    with numpyro.plate("n_levels", n_levels, dim=-2):

        a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(a_level_mean_global_scale)).reshape(n_levels, 1, 4)
        b_level_mean = numpyro.sample("b_level_mean", dist.HalfNormal(b_level_mean_global_scale)).reshape(n_levels, 1, 4)

        a_level_scale = numpyro.sample("a_level_scale", dist.HalfNormal(a_level_scale_global_scale))
        b_level_scale = numpyro.sample("b_level_scale", dist.HalfNormal(b_level_scale_global_scale))

        with numpyro.plate("n_participants", n_participants, dim=-1):
            
            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(a_level_scale)))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(b_level_scale)))

    sigma = numpyro.sample('sigma', dist.HalfCauchy(3*jnp.ones(4)))
    cov = jnp.diag(sigma)

    mean = jax.nn.relu(jnp.multiply(b[level, participant], jnp.tile(intensity, (4,1)).T - a[level, participant]))

    with numpyro.plate("data", len(intensity)):
        return numpyro.sample("obs", dist.MultivariateNormal(mean, cov), obs=mep_size_obs)
        
numpyro.render_model(model, model_args=(intensity, participant, level, mep_size))

My model is rendering properly and I can see the plates figure. However, when I do:

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=10000, num_warmup=10000)
    rng_key = random.PRNGKey(1)
    mcmc.run(rng_key, intensity, participant, level, mep_size)

    posterior_samples = mcmc.get_samples()

I get the following error:

ValueError: MultivariateNormal distribution got invalid covariance_matrix parameter.
    print(a_level_mean_global_scale.shape)
    print(a_level_scale_global_scale.shape)

    print(a_level_mean.shape)
    print(a_level_scale.shape)

    print(a.shape)
    print(mean.shape, cov.shape)

gives

(4,)
(4,)

(6, 1, 4)
(6, 4)

(6, 18, 4)
(1497, 4) (4, 4)

If I replace

            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(a_level_scale)))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(b_level_scale)))

with

            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, jnp.diag(jnp.ones(4))))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, jnp.diag(jnp.ones(4))))

the code works properly.

I figured the error is coming from jnp.diag(a_level_scale) because a_level_scale is (6,4), which is 2D. Looking at jnp.diag documentation, I found that jnp.diag(a_level_scale) results in a (4,) array instead of a (6,4,4) array.

Is there a work around for this?

If you want to create diagonal matrix, you can use jax.vmap(jnp.diag)(...)

Hi @fehiepsi thanks for your reply. Could you please tell me what should be the shape of covariance matrix fed to a = numpyro.sample(“a”, dist.MultivariateNormal(a_level_mean, jnp.diag(a_level_scale)))? When I’m passing jnp.tile( jnp.diag(a_level_scale), (6,1,1)), which is (6,4,4) shape, I’m getting broadcast error.

I’m getting the same error with jax vmap as I’m getting when I pass jnp.tile( jnp.diag(a_level_scale), (6,1,1))

It doesn’t like covariance shape of (6,4,4) which has to be broadcasted to (6,18)

ValueError: Incompatible shapes for broadcasting: ((1, 18), (6, 4))

Instead if I pass, jnp.tile(jnp.diag(jnp.ones(4)), (6,18,1,1)) which is of shape (6,18,4,4), the code runs properly.

@fehiepsi thanks for your help. I made it work using

            a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, \
                jnp.tile(jax.vmap(jnp.diag)(a_level_scale)[:,jnp.newaxis,:,:], (1,18,1,1))))
            b = numpyro.sample("b", dist.MultivariateNormal(b_level_mean, \
                jnp.tile(jax.vmap(jnp.diag)(b_level_scale)[:,jnp.newaxis,:,:], (1,18,1,1))))

but mcmc is taking 2 hours to run on this. Is there anyway to optimize this? (it might already be, just that it’s due to multivariate nature of it)

I’m quite understand the model (e.g. why using MVNormal instead of Normal) but I guess you need

a_level_mean = numpyro.sample("a_level_mean", dist.HalfNormal(a_level_mean_global_scale).to_event(1))
...
a = numpyro.sample("a", dist.MultivariateNormal(a_level_mean, a_level_scale[..., None] * jnp.eye(4)))
# or a = numpyro.sample("a", dist.Normal(a_level_mean, a_level_scale).to_event(1))