Problem in MCMC Sampling of Categorical Variable

I’m trying to implement a mixture model with multiple components, so I want to sample a categorical variable to randomly divide the data points into clusters. However, when I try to sample dist.Categorical through MCMC, even without any observations, the shape of the sample suddenly changes on the second draw. The following code reproduces it.

What am I doing wrong? Should I approach this differently?
Note that just using dist.Categorical outside any model works fine.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer.mcmc.api import HMC, MCMC, NUTS

    def model():
        classes = pyro.sample(
            "classes",
            dist.Bernoulli(torch.tensor(
                [[0.3, 0.7], [0.3, 0.7], [0.3, 0.7], [0.3, 0.7],
                 [0.3, 0.7], [0.3, 0.7], [0.3, 0.7], [0.3, 0.7]])))
        print(classes.shape)

    mcmc = MCMC(HMC(model, target_accept_prob=0.8),
                num_samples=10,
                warmup_steps=10,
                num_chains=1)
    mcmc.run()

See Sample size changes over NUTS MCMC sampling.

HMC/NUTS require models with continuous parameters. If a model has discrete parameters, HMC enumerates them in parallel to compute the log density of a model execution trace. This parallel enumeration reserves additional batch dimensions to record the downstream sample values for each possible value of the categorical variable, which is why the shape you print changes between model executions.
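
As a rough illustration (the two-component GMM, toy data, and site names below are made up, not taken from your model), this is what a mixture with an enumerated discrete site looks like under NUTS: the assignment site never shows up in the MCMC samples because it is marginalized out, and only the continuous sites remain.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer.mcmc.api import MCMC, NUTS

    data = torch.tensor([0.0, 0.1, 10.0, 11.0])

    def gmm(data):
        weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
        with pyro.plate("data", len(data)):
            # NUTS enumerates this discrete site in parallel instead of sampling it,
            # which is what adds the extra batch dimension you observed.
            assignment = pyro.sample("assignment", dist.Categorical(weights))
            pyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

    mcmc = MCMC(NUTS(gmm, max_plate_nesting=1), num_samples=100, warmup_steps=100)
    mcmc.run(data)
    # Only "weights" and "locs" appear; "assignment" is marginalized out.
    print(mcmc.get_samples().keys())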

Two follow-up questions:

  1. Why do HMC/NUTS require continuous parameters?
  2. How does one then sample from the posterior of cluster assignments in, say, a mixture of Gaussians? Is there a code snippet?

HMC requires gradients for its dynamics, so we need continuous variables. Recently there has been research on making HMC work with discrete latent variables; the MixedHMC paper summarizes some of the progress in that direction. For the second question, I think you can use infer_discrete.

It looks like infer_discrete is a MAP estimator? It’s pretty frustrating that I have to do all this extra work, with a high probability of screwing something up, for such a simple model.

It is not just Viterbi-like MAP inference; by default, infer_discrete samples the discrete latent variables via forward-filter backward-sample (see the temperature parameter in the docs).
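
For concreteness, here is a rough sketch of using infer_discrete to draw per-data-point assignments given fixed values for the continuous sites; the toy model, data, and conditioning values are made up for illustration.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.infer import config_enumerate, infer_discrete

    data = torch.tensor([0.0, 0.1, 10.0, 11.0])

    @config_enumerate
    def model(data):
        weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
        with pyro.plate("data", len(data)):
            z = pyro.sample("z", dist.Categorical(weights))
            pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

    # Fix the continuous sites, e.g. to one posterior draw from MCMC
    # (the values here are placeholders).
    conditioned = poutine.condition(model, data={
        "weights": torch.tensor([0.5, 0.5]),
        "locs": torch.tensor([0.0, 10.0]),
    })

    # temperature=1 samples z by forward-filter backward-sample;
    # temperature=0 would give the Viterbi-like MAP assignment instead.
    sampler = infer_discrete(conditioned, first_available_dim=-2, temperature=1)
    trace = poutine.trace(sampler).get_trace(data)
    print(trace.nodes["z"]["value"])  # one assignment per data point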

IMO, working with mixture models is not easy. If you want something out of the box, I would recommend DiscreteHMCGibbs (the example in its docs is a 1D GMM); see the sketch below. That class will sample both discrete and continuous latent variables for you, without extra work. In some cases the HMC-within-Gibbs algorithm performs better than marginalization, but in many cases (to my knowledge) marginalization performs better.
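
A minimal, self-contained sketch of DiscreteHMCGibbs usage, roughly following the example in the docs (the mixture probabilities and locations below are arbitrary):

    from jax import random
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

    def model(probs, locs):
        # one discrete mixture indicator c and one continuous value x per execution
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2.0, 0.0, 2.0, 4.0])

    # NUTS updates the continuous sites; the Gibbs step updates the discrete ones.
    kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    mcmc.print_summary()
    samples = mcmc.get_samples()  # dict with both "c" and "x" draws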

Thanks for pointing me in the right direction! I’m a bit confused by the samples output by DiscreteHMCGibbs; I’ve included my code below. The confusing bit is that with C clusters, C-1 mixture proportions, C cluster parameters, and D data points, I would expect a sample of cluster assignments to have shape (num samples, num data points), i.e. one cluster assignment per data point. Instead, I find that the output shape of the cluster assignments is (num samples,). Why?

    import numpyro
    import jax.numpy as jnp
    from jax import random

    # sampling_max_num_clusters, alpha, obs_dim, num_obs, the covariance scalings,
    # mix_weights, and observations are defined elsewhere in my code.
    def model(obs):
        with numpyro.plate('beta_plate', sampling_max_num_clusters - 1):
            beta = numpyro.sample(
                'beta',
                numpyro.distributions.Beta(1, alpha))

        with numpyro.plate('mean_plate', sampling_max_num_clusters):
            mean = numpyro.sample(
                'mean',
                numpyro.distributions.MultivariateNormal(
                    jnp.zeros(obs_dim),
                    gaussian_mean_prior_cov_scaling * jnp.eye(obs_dim)))

        with numpyro.plate('data', num_obs):
            z = numpyro.sample(
                'z',
                numpyro.distributions.Categorical(mix_weights(beta=beta)).mask(False))
            numpyro.sample(
                'obs',
                numpyro.distributions.MultivariateNormal(
                    mean[z],
                    gaussian_cov_scaling * jnp.eye(obs_dim)),
                obs=obs)

    hmc_kernel = numpyro.infer.NUTS(model)
    kernel = numpyro.infer.DiscreteHMCGibbs(hmc_kernel, modified=True)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=100, num_samples=89, progress_bar=True)
    mcmc.run(random.PRNGKey(0), obs=observations)
    mcmc.print_summary()
    samples = mcmc.get_samples()
The short example from the docs behaves the same way: the sampled Categorical variable has shape (num samples,) rather than (num samples, num observations). Shouldn’t there be one sampled Categorical per observation?

    def model(probs, locs):
        c = numpyro.sample("c", numpyro.distributions.Categorical(probs))
        numpyro.sample("x", numpyro.distributions.Normal(locs[c], 0.5))

You are right, there must be a bug somewhere in the current code. It is surprising to me that we don’t have tests for discrete sites under plate yet. I’ll address it soon. Could you help me make a GitHub issue, so I won’t forget? Thanks!

Happy to create a GitHub issue 🙂 Will link in just a bit.

Issue is here: DiscreteHMCGibbs Draws Samples of Incorrect Shape · Issue #931 · pyro-ppl/numpyro · GitHub
