Getting auto-enumeration working with transformed distributions


Hi everyone!

I am trying to get auto-enumeration working with a transformed distribution, but inference breaks, apparently because the distribution's parameters can't be expanded correctly.
Below is an MFE for a mixture model:

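```python
import torch
import pyro
import pyro.distributions as dist
import torch.distributions.transforms as trans
from pyro.infer.mcmc import MCMC, NUTS

def model(data):
    # Dirichlet prior over the weights of the 3 mixture components
    ays = torch.tensor([1.0, 1.0, 1.0])
    ws = pyro.sample('ws', dist.Dirichlet(ays))
    # One mean per component
    means = pyro.sample('means', dist.Uniform(-10, 30).expand_by([3]))
    with pyro.iarange("data", len(data)) as idx:
        # Discrete component assignment, enumerated in parallel
        choice = pyro.sample('choice', dist.Categorical(probs=ws), infer=dict(enumerate='parallel'))
        base_dist = dist.Normal(means[choice], 1.0)
        transforms = [trans.identity_transform]
        pyro.sample('obs', dist.TransformedDistribution(base_dist, transforms), obs=data)

def infer(data):
    nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=2)
    mcmc = MCMC(nuts_kernel, 100, 100)  # num_samples=100, warmup_steps=100
    trace = mcmc.run(data)
    return trace
```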

It works if I simply use base_dist instead of the transformed distribution.

Any idea how to solve this issue?

Thank you very much in advance.

The full notebook is available here: (please run it locally because of an issue with Colaboratory and tqdm_notebook).


@neerajprad it appears this may be a bug in pyro.distributions.TransformedDistribution, does that seem correct?

@ahmadsalim I would recommend using PyTorch 0.4.0 rather than PyTorch 0.4.1 due to numerous bugs in 0.4.1’s torch.expand() method.


Great, thanks for the reply! I will update PyTorch.


The issue is that .expand for TransformedDistribution in Pyro is a bit restrictive: it does not allow us to expand on the right, e.g. from a batch shape of [3, 1] to [3, 3000]. This should, however, work fine with the PyTorch nightly release, because there the .expand call is dispatched to torch.distributions.TransformedDistribution.expand, which handles such cases. Follow the instructions here if you would like to experiment with the Pyro branch tracking the upcoming release.
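To illustrate what expanding on the right means, here is a minimal sketch using plain torch.distributions rather than Pyro's wrapper. It assumes a PyTorch release in which TransformedDistribution.expand exists (1.0 or later); the shapes mirror the [3, 1] to [3, 3000] case described above.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import identity_transform

# Base distribution with batch shape [3, 1]: one row per mixture component,
# plus a singleton dimension on the right where the data dimension will go.
base = Normal(torch.zeros(3, 1), 1.0)
td = TransformedDistribution(base, [identity_transform])
print(td.batch_shape)  # torch.Size([3, 1])

# "Expanding on the right": the trailing singleton dim grows to the data size.
# Assumes PyTorch >= 1.0, where expand is forwarded to the base distribution.
expanded = td.expand([3, 3000])
print(expanded.batch_shape)            # torch.Size([3, 3000])
print(expanded.base_dist.batch_shape)  # torch.Size([3, 3000])
```

The expanded base distribution carries the full [3, 3000] batch shape, which is what the enumerated model needs when it scores all data points against every component.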

Why do we need such expands? While we typically don't encounter them, they become a problem here because the preceding call to Categorical.sample calls .enumerate_support under the hood, collapsing the batch shape dim into a singleton dim. So we will face this issue whenever a model contains both enumerated discrete sites and a TransformedDistribution. Note that discrete site enumeration has not yet been publicly released, so this should be fixed by the next Pyro release, which will use PyTorch 1.0!
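The shapes involved can be sketched with plain torch.distributions. This is a hypothetical reconstruction, not Pyro's enumeration code: the means, weights and data below are made-up values, and the reshape to [3, 1] stands in for where Pyro places the enumerated dimension (the real position depends on max_iarange_nesting). It assumes PyTorch 1.0 or later.

```python
import torch
from torch.distributions import Categorical, Normal, TransformedDistribution
from torch.distributions.transforms import identity_transform

num_data = 3000
data = torch.randn(num_data)             # hypothetical observations, shape [3000]
means = torch.tensor([-5.0, 0.0, 5.0])   # hypothetical component means
ws = torch.ones(3) / 3                   # hypothetical uniform mixture weights

# Enumeration replaces a single random draw with every support value at once.
# Assumption: the enumerated values sit in one dim to the left of the
# (singleton) data dim, giving shape [3, 1].
choice = Categorical(probs=ws).enumerate_support(expand=False).reshape(3, 1)

base_dist = Normal(means[choice], 1.0)   # batch shape [3, 1]
td = TransformedDistribution(base_dist, [identity_transform])

# Matching the data requires growing the right-hand singleton dim to 3000,
# i.e. exactly the "expand on the right" discussed above.
lp = td.expand([3, num_data]).log_prob(data)
print(lp.shape)  # torch.Size([3, 3000])
```

With a plain Normal, broadcasting alone gives the same [3, 3000] result, which is consistent with the model working when base_dist is used directly.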


Great, thanks for the response! :slight_smile: I will see how I can get PyTorch 1.0 running.