Hierarchical mixture model, likelihood and broadcasting

Hi community,

I have made a hierarchical version of a Gaussian mixture model. I would like access to the likelihood at each iteration of the MCMC, especially when I run multiple chains, but I get broadcasting errors. The model otherwise seems to work fine: it correctly identifies the mixture components on simulated data.

This is the model:

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist


def Model_HGMM(K=2, dimension=2, data=None, label=None):
    l = len(np.unique(label))  # number of age groups

    # Global component means and (LKJ) matrices, one per component
    with numpyro.plate("components", K):
        locs = numpyro.sample(
            "locs",
            dist.MultivariateNormal(jnp.zeros(dimension), 10 * jnp.eye(dimension)),
        )
        sigma = numpyro.sample("sigma", dist.LKJ(dimension, concentration=1))

    # Per-age-group perturbation of each component mean.
    # Without an explicit dim, the outermost plate takes the rightmost dim,
    # so locs_perturb has shape (l, K, dimension).
    with numpyro.plate("dimension", dimension):
        with numpyro.plate("components", K):
            with numpyro.plate("Age group", l):
                locs_perturb = numpyro.sample(
                    "locs_perturb",
                    dist.Normal(loc=0, scale=0.1),
                )

    # Mixture weights per age group, shape (l, K)
    with numpyro.plate("Age group", l):
        cluster_proba = numpyro.sample(
            "cluster_proba", dist.Dirichlet(jnp.ones(K))
        )

    print("locs shape", locs.shape)
    print("locs_perturb shape", locs_perturb.shape)
    print("sigma shape", sigma.shape)

    with numpyro.plate("data", len(data)):
        # Component assignment, enumerated (summed out) during MCMC
        assignment = numpyro.sample(
            "assignment",
            dist.Categorical(cluster_proba[label]),
            infer={"enumerate": "parallel"},
        )
        print("locs[assignment] shape", locs[assignment].shape)
        print("locs_perturb[label, assignment, :] shape", locs_perturb[label, assignment, :].shape)
        print("sigma[assignment] shape", sigma[assignment].shape)

        numpyro.sample(
            "obs",
            dist.MultivariateNormal(
                locs[assignment] + locs_perturb[label, assignment, :],
                covariance_matrix=sigma[assignment],
            ),
            obs=data,
        )
```
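For reference, the shapes that the indexing in the "data" plate relies on can be checked with NumPy stand-ins. This is a sketch under assumptions: K=2, dimension=2, three age groups and 600 data points (the sizes that appear in the error message), with zero-filled placeholder arrays instead of real samples. Because NumPyro allocates nested plates without an explicit dim from the right, the outermost plate ("dimension") sits at dim -1 and "Age group" at dim -3, so locs_perturb is laid out as (age groups, components, dimension).

```python
import numpy as np

K, dimension, n_groups, n_data = 2, 2, 3, 600  # assumed sizes for illustration

rng = np.random.default_rng(0)
label = rng.integers(0, n_groups, size=n_data)   # age group of each point
assignment = rng.integers(0, K, size=n_data)     # one sampled component per point

# Placeholder arrays with the shapes the model's sample sites produce
locs = np.zeros((K, dimension))                                        # plate "components"
sigma = np.broadcast_to(np.eye(dimension), (K, dimension, dimension))  # plate "components"
locs_perturb = np.zeros((n_groups, K, dimension))                      # three nested plates
cluster_proba = np.full((n_groups, K), 1.0 / K)                        # plate "Age group"

mean = locs[assignment] + locs_perturb[label, assignment, :]
print(cluster_proba[label].shape)  # (600, 2): one probability vector per point
print(mean.shape)                  # (600, 2): one mean vector per point
print(sigma[assignment].shape)     # (600, 2, 2): one matrix per point
```

With a single set of parameter values, every per-point quantity has a leading axis of 600, which lines up with the "data" plate.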

And this is the error I get when I try to compute the likelihood using NumPyro handlers:

ValueError: Incompatible shapes for broadcasting: shapes=[(600,), (600, 3)]
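JAX follows NumPy's broadcasting rules: shapes are aligned from the right, and each pair of sizes must be equal or 1. A minimal NumPy reproduction of this particular mismatch (600 against 3 in the last position) looks like this; the shapes are taken from the error message.

```python
import numpy as np

plate_shape = (600,)   # batch shape expected by the "data" plate
site_shape = (600, 3)  # batch shape reported by the failing site

broadcast_ok = True
try:
    np.broadcast_shapes(plate_shape, site_shape)
except ValueError as err:
    # Rightmost sizes are 600 and 3: neither equal nor 1
    broadcast_ok = False
    print("cannot broadcast:", err)

# By contrast, an extra axis on the left broadcasts fine
print(np.broadcast_shapes((600,), (3, 600)))  # (3, 600)
```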

I have made a Jupyter notebook on simulated data to make it clearer.

Thanks in advance for your help!

Hi victor, welcome back to the forum.

This is a weird one. Given it runs, there can’t be too much wrong with the shapes.

So looking at your notebook, you’re running

```python
from numpyro.handlers import substitute, trace

mcmc_samples = mcmc.get_samples()
conditioned_model = substitute(Model_HGMM, mcmc_samples)

traced_model = trace(conditioned_model).get_trace(K=2, dimension=2, data=data, label=label)
log_likelihood = traced_model['obs']['fn'].log_prob(traced_model['obs']['value'])
```

I assume you copied the last line straight from the tutorial in the docs, so it presumably works there. The error is raised before the log_likelihood calculation, so it's failing inside trace(). It's really weird that the model runs under MCMC but fails when traced.
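One thing that may be worth checking (an assumption, not confirmed in this thread): mcmc.get_samples() returns every draw stacked along a leading axis, and substitute hands those whole arrays to the model. Inside the model, cluster_proba[label] would then index the draw axis instead of the age-group axis. The NumPy sketch below uses a hypothetical 1000 draws to show how that yields a Categorical batch shape of (600, 3), the shape in the error, whereas a single draw gives (600,).

```python
import numpy as np

n_draws, n_groups, K, n_data = 1000, 3, 2, 600  # hypothetical sizes
rng = np.random.default_rng(0)
label = rng.integers(0, n_groups, size=n_data)

# Stacked posterior draws, laid out as (draws, age groups, K)
cluster_proba_draws = rng.dirichlet(np.ones(K), size=(n_draws, n_groups))

# Whole stack substituted: label indexes the draw axis
probs_all = cluster_proba_draws[label]
# A single draw substituted: label indexes the age-group axis as intended
probs_one = cluster_proba_draws[0][label]

# A Categorical's batch shape is probs.shape[:-1]
print(probs_all.shape[:-1])  # (600, 3): clashes with the (600,) data plate
print(probs_one.shape[:-1])  # (600,): matches the data plate
```

If this is the cause, tracing the model one draw at a time (for example by mapping over the sample dictionary, which is what numpyro.infer.log_likelihood does internally) should avoid the clash. The enumerated assignment site is not among the posterior samples, though, so it would still need to be either seeded or summed out.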

I'm unfamiliar with infer={"enumerate": "parallel"}, but the error suggests those lines of the model are causing the issue here. I think @fritzo wrote that code, so they might be able to provide more insight. Otherwise @fehiepsi is the best ever at fixing shape issues. But I'll have a longer think about it too.

Sorry for not being more helpful.

PS: In the broadcasting error, 600 is the length of the data (so the size of the "data" plate) and 3 is the number of labels (len(np.unique(label))). You could almost certainly tidy up this model using the dim argument of plate, which generally helps avoid broadcasting errors, although it's slightly trickier with covariance matrices, which take up two dimensions.
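The same idea of keeping every axis in a fixed position can be sketched in plain NumPy for the quantity the original question asks for: the per-draw marginal log-likelihood, with the assignment summed out via logsumexp, i.e. log p(x_n) = logsumexp_k [log cluster_proba[g_n, k] + log N(x_n | locs[k] + locs_perturb[g_n, k], sigma[k])]. This is a hedged sketch, not NumPyro's API: the function names mixture_loglik and mvn_logpdf are hypothetical, the inputs are assumed to carry a leading draw axis in the layouts the model's sites produce, and the sigma matrices are used as covariances exactly as in the model.

```python
import numpy as np


def mvn_logpdf(x, mean, cov):
    """Log density of N(mean, cov) at x; batch dims of x, mean and cov broadcast."""
    d = x.shape[-1]
    chol = np.linalg.cholesky(cov)
    diff = x - mean
    # z = L^{-1} (x - mean), so z.z is the squared Mahalanobis distance
    z = np.linalg.solve(chol, diff[..., None])[..., 0]
    log_det = 2.0 * np.sum(np.log(np.diagonal(chol, axis1=-2, axis2=-1)), axis=-1)
    return -0.5 * (np.sum(z**2, axis=-1) + log_det + d * np.log(2.0 * np.pi))


def mixture_loglik(x, label, locs, locs_perturb, sigma, cluster_proba):
    """Hypothetical helper: per-draw marginal log-likelihood of each point.

    Assumed shapes (S draws, N points, G groups, K components, D dims):
    x (N, D), label (N,), locs (S, K, D), locs_perturb (S, G, K, D),
    sigma (S, K, D, D), cluster_proba (S, G, K). Returns (S, N).
    """
    means = locs[:, None, :, :] + locs_perturb[:, label, :, :]  # (S, N, K, D)
    covs = sigma[:, None, :, :, :]                               # (S, 1, K, D, D)
    log_lik = mvn_logpdf(x[None, :, None, :], means, covs)       # (S, N, K)
    log_w = np.log(cluster_proba[:, label, :])                   # (S, N, K)
    a = log_w + log_lik
    # Numerically stable logsumexp over the component axis
    a_max = a.max(axis=-1, keepdims=True)
    return (a_max + np.log(np.exp(a - a_max).sum(axis=-1, keepdims=True)))[..., 0]


# Usage with random stand-in values (hypothetical sizes)
rng = np.random.default_rng(0)
S, N, G, K, D = 4, 600, 3, 2, 2
out = mixture_loglik(
    rng.normal(size=(N, D)),
    rng.integers(0, G, size=N),
    rng.normal(size=(S, K, D)),
    0.1 * rng.normal(size=(S, G, K, D)),
    np.broadcast_to(np.eye(D), (S, K, D, D)),
    rng.dirichlet(np.ones(K), size=(S, G)),
)
print(out.shape)        # (4, 600): one log-likelihood per draw and data point
print(out.sum(axis=1))  # total log-likelihood of the data set for each draw
```

Summing over axis 1 gives one total log-likelihood per draw, so draws from several chains can be stacked along the first axis and compared. The function body should translate to jax.numpy with minor changes (jax.scipy.special.logsumexp can replace the manual last step).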