I am trying to get auto-enumeration working with a `TransformedDistribution`, but inference seems to break because the parameters are not expanded correctly.
Below is a minimal failing example (MFE) for a mixture model:
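```python
import torch
import pyro
import pyro.distributions as dist
import torch.distributions.transforms as trans
from pyro.infer.mcmc import MCMC, NUTS

def model(data):
    ays = torch.tensor([1.0, 1.0, 1.0])
    ws = pyro.sample('ws', dist.Dirichlet(ays))
    means = pyro.sample('means', dist.Uniform(-10., 30.).expand_by([3]))  # one mean per component
    with pyro.iarange("data", len(data)) as idx:
        choice = pyro.sample('choice', dist.Categorical(probs=ws), infer=dict(enumerate='parallel'))
        base_dist = dist.Normal(means[choice], 1.0)
        transforms = [trans.identity_transform]
        # obs=data[idx] is assumed here; the original line was cut off
        pyro.sample('obs', dist.TransformedDistribution(base_dist, transforms), obs=data[idx])

nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=2)
mcmc = MCMC(nuts_kernel, 100, 100)
trace = mcmc.run(data)  # data: 1-D tensor of observations from the notebook
```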
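For reference, parallel enumeration of `choice` should marginalize out the component assignment, so each observation contributes log Σ_k w_k · N(x | means[k], 1) to the log joint. The plain-Python sketch below computes that quantity so the model's log density can be sanity-checked with and without the transform. The helper names (`normal_log_prob`, `logsumexp`, `mixture_log_prob`) and the example values are hypothetical and are not Pyro API.

```python
import math

def normal_log_prob(x, loc, scale=1.0):
    # log density of Normal(loc, scale) evaluated at x
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def logsumexp(values):
    # numerically stable log(sum(exp(v)))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mixture_log_prob(x, ws, means):
    # sum over every enumerated value k of `choice`:
    # log sum_k w_k * Normal(x | means[k], 1)
    return logsumexp([math.log(w) + normal_log_prob(x, mu) for w, mu in zip(ws, means)])

# hypothetical example values, not taken from the notebook
ws = [0.2, 0.5, 0.3]
means = [0.0, 10.0, 20.0]
data = [0.5, 9.0, 21.0]
total = sum(mixture_log_prob(x, ws, means) for x in data)
```

Enumeration replaces sampling `choice` with this sum over its three possible values, which is why the enumerated dimension has to broadcast cleanly through whatever distribution `obs` is drawn from.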
It works if I simply use `base_dist` instead of the transformed distribution.
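Since the transform list contains only the identity, the transformed distribution should score observations exactly as `base_dist` does. `TransformedDistribution.log_prob` applies the change-of-variables rule log p_Y(y) = log p_X(T⁻¹(y)) − log|det J_T(T⁻¹(y))|, and the Jacobian term is zero for the identity. The minimal sketch below illustrates this with a scalar affine transform y = a + b·x; the affine transform and the helper names are hypothetical stand-ins chosen because the Jacobian is easy to write, not the PyTorch/Pyro implementation.

```python
import math

def normal_log_prob(x, loc, scale=1.0):
    # log density of Normal(loc, scale) evaluated at x
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def affine_transformed_log_prob(y, base_loc, a=0.0, b=1.0):
    # y = a + b * x with x ~ Normal(base_loc, 1)
    x = (y - a) / b                          # inverse transform T^-1(y)
    log_abs_det_jacobian = math.log(abs(b))  # |dy/dx| = |b|
    return normal_log_prob(x, base_loc) - log_abs_det_jacobian

# a=0, b=1 is the identity: the Jacobian term is log(1) = 0,
# so the result equals the base Normal log density.
identity_lp = affine_transformed_log_prob(1.3, 0.4)
```

If the identity-transformed model fails while the plain `base_dist` model runs, the difference therefore lies in how the wrapped distribution handles the extra enumeration batch dimension, not in the density being computed.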
Any idea how this could be solved?
Thank you very much in advance.
The full notebook is available here: https://colab.research.google.com/drive/1_FpevpwjIugN0MFC4UtmO43_yyzApnPg (please run it locally because of an issue with Colaboratory and `tqdm_notebook`).