Sample size changes over NUTS MCMC sampling

Hi Pyro people,

I’m trying to implement Dirichlet process clustering in Pyro and started with inference by MCMC sampling. Sadly, my code crashes every time: starting from the second round of MCMC sampling, one of my samples starts changing in size, and I cannot find the reason. Here’s a somewhat minimal example:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

data = torch.ones(10)
data[:5] = 0

def model(data):

    alpha = pyro.sample("alpha", dist.Normal(5, 2)).expand(1)

    counts = torch.empty(0) # atom counts
    atoms = torch.empty(0)

    saves = torch.zeros(len(data))

    for i in range(len(data)):
        print(torch.cat((counts, alpha)))
        print(dist.Categorical(torch.cat((counts, alpha)) / (torch.sum(counts) + alpha)).sample())
        num = pyro.sample("num_{}".format(i), dist.Categorical(torch.cat((counts, alpha)) / (torch.sum(counts) + alpha)))

        if num == len(atoms):
            atom = pyro.sample("atom_{}".format(atoms.size()[0]), dist.Normal(torch.zeros(1), 1))
            atoms = torch.cat((atoms, atom))
            counts = torch.cat((counts, torch.ones(1)))
        else:
            counts[num] += 1
            atom = atoms[num]

        saves[i] = pyro.sample("obs_{}".format(i), dist.Normal(atom, .1), obs=data[i])

    return saves

nuts_kernel = NUTS(model, jit_compile=False,)
posterior = MCMC(nuts_kernel,

Everything from the if clause onward in the model function can be deleted and the problem still shows; it just looks different, so I kept it in for demonstration purposes. The important thing is that in the first round (with the first alpha sample), num is always a 0-d tensor, but in the second round (second alpha), num starts gaining dimensions, which no longer works with the num == x clause. Is that a property of the MCMC sampler, or have I misunderstood something badly? I tried to stick to the MCMC tutorial as closely as possible.

Thanks a lot for any help.

OK, I have now learned that MCMC sampling with a variable number of discrete latent variables requires enumeration. From a smaller example I made, I also learned that NUTS seems to perform parallel enumeration automatically, which does not work here because the control flow depends on whether a new atom is drawn. I can force sequential enumeration, but:

  1. It gives incorrect results even for very small examples that can also be run with parallel enumeration
  2. The chain immediately gets stuck on one specific number of atoms and never samples any other possibility, defeating the purpose of allowing a variable number of hidden variables
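To get a feel for why enumerating this model is expensive, here is a rough count of its execution paths, assuming a sequential enumerator has to rerun the model once for each distinct combination of `num_i` choices. Point i either reuses one of the k atoms created so far or takes the `if num == len(atoms)` branch and creates a new one. `count_execution_paths` is a hypothetical helper in pure Python, with no Pyro calls:

```python
def count_execution_paths(n_points):
    """Count distinct execution paths through a CRP-style model in which each
    point either reuses one of the k atoms created so far or creates a new one."""
    # paths_by_k[k] = number of partial paths that have created k atoms so far
    paths_by_k = {0: 1}
    for _ in range(n_points):
        next_paths = {}
        for k, count in paths_by_k.items():
            # reuse one of the k existing atoms (k choices)
            if k > 0:
                next_paths[k] = next_paths.get(k, 0) + count * k
            # or take the "new atom" branch (one choice)
            next_paths[k + 1] = next_paths.get(k + 1, 0) + count
        paths_by_k = next_paths
    return sum(paths_by_k.values())


for n in (5, 10, 20):
    print(n, count_execution_paths(n))
```

These counts are the Bell numbers (each path is one way of partitioning the points into clusters); for the 10 data points in the example above that is already 115975 paths.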

(I also read that parallel enumeration can only be used in the model?)
I am now stuck on every front with these questions; does anyone have some insight? I can provide code demonstrating the problems if that helps.

HMC assumes a static model structure, but it can handle models such as Gaussian mixture models or HMMs, where the discrete variables can be enumerated in parallel. Even then, sampling may be too slow to be practical for medium to large models. I would suggest looking at variational inference for this. You might find this discussion useful: Variational Inference for Dirichlet process clustering.

Huh, OK, that explains it. From the first link I mentioned, I understood that it is possible to handle dynamic structure with enumeration, but now I can stop trying. Thanks.

I actually tried SVI at the start and also found that discussion, but gave up on it because I couldn’t get it to work. I tried again and got further this time, but I’m still having trouble with dynamic structure. For starters, I would just like to sample the number of atoms at the start, but I always end up indexing outside my array because I draw an atom index that is too high, even though the way I set it up that shouldn’t be possible, and inserted print statements confirm as much:

def model(data):
    n = pyro.sample('n', dist.Categorical(torch.ones(5)))
    n = n.item() +1
    atoms = torch.zeros(n)

    for i in pyro.plate("atoms", n):
        atoms[i] = pyro.sample("atom_{}".format(i), dist.Normal(torch.zeros(1), 1))

    for i in pyro.plate("data", len(data)):
        num = pyro.sample("num_{}".format(i), dist.Categorical(torch.ones(n)))

        pyro.sample("obs_{}".format(i), dist.Normal(atoms[num], .1), obs=data[i])

    return atoms

def guide(data):
    n_probs = pyro.param('n_delta', torch.ones([5]))
    n_delta = softi(n_probs)
    n = pyro.sample('n', dist.Categorical(n_delta))
    n = n.item() +1
    atom_delta = torch.zeros(n)
    for i in pyro.plate("atoms", n):
        atom_delta[i] = pyro.param('atom_delta_{}'.format(i), torch.zeros(1))
        pyro.sample('atom_{}'.format(i), dist.Delta(atom_delta[i]))

    for i in pyro.plate("data", len(data)):
        z_i_probs_unconstrained = pyro.param("num_thing_{}".format(i), torch.ones(n))
        z_i_probs = softi(z_i_probs_unconstrained)
        pyro.sample("num_{}".format(i), dist.Categorical(z_i_probs))

    return atom_delta

Am I overlooking a stupid bug, or do I need to do things differently?