I am trying to get auto-enumeration working with a transformed distribution, but inference seems to break because it can't expand the parameters correctly.
I have tried to make an MFE for a mixture model below:
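```python
import torch
import torch.distributions.transforms as trans  # assumed source of `trans`
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS


def model(data):
    # Dirichlet prior over the weights of the 3 mixture components
    ays = torch.tensor([1.0, 1.0, 1.0])
    ws = pyro.sample('ws', dist.Dirichlet(ays))
    # One uniform prior per component mean; expand_by() needs a sample
    # shape, and [3] matches the three components above
    means = pyro.sample('means', dist.Uniform(-10, 30).expand_by([3]))
    with pyro.iarange("data", len(data)) as idx:
        # Component assignment, summed out by parallel enumeration
        choice = pyro.sample('choice', dist.Categorical(probs=ws),
                             infer=dict(enumerate='parallel'))
        base_dist = dist.Normal(means[choice], 1.0)
        transforms = [trans.identity_transform]
        pyro.sample('obs', dist.TransformedDistribution(base_dist, transforms),
                    obs=data[idx])


def infer(data):
    nuts_kernel = NUTS(model, adapt_step_size=True, max_iarange_nesting=2)
    mcmc = MCMC(nuts_kernel, 100, 100)
    trace = mcmc.run(data)
    return trace
```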
It works if I simply use `base_dist` instead of the transformed distribution.
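A quick way to narrow this down, sketched here outside Pyro with plain `torch.distributions`, is to check whether the transformed distribution itself evaluates `log_prob` the same way as the base `Normal` over an enumeration-shaped batch. The component means, the data, and the `(3, 1)` shape used to mimic the enumerated `choice` tensor are illustrative assumptions, not values taken from the notebook.

```python
import torch
import torch.distributions as tdist
import torch.distributions.transforms as trans

# Hypothetical stand-ins: 3 mixture components and 5 observations.
means = torch.tensor([-2.0, 0.0, 5.0])
data = torch.randn(5)

# Assumption: under parallel enumeration, 'choice' holds every component
# index in a dimension left of the data dimension; (3, 1) mimics that here.
choice = torch.arange(3).reshape(3, 1)

base_dist = tdist.Normal(means[choice], 1.0)  # batch_shape (3, 1)
transformed = tdist.TransformedDistribution(base_dist, [trans.identity_transform])

# Both log-probabilities broadcast (3, 1) against (5,) to shape (3, 5).
base_lp = base_dist.log_prob(data)
transformed_lp = transformed.log_prob(data)
print(base_lp.shape, transformed_lp.shape)
print(torch.allclose(base_lp, transformed_lp))
```

If the two agree, the transform's log-density computation is not the problem by itself, which points towards how the transformed site is handled during enumeration.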
Any idea how to solve this issue?
Thank you very much in advance.
The full notebook is available here: https://colab.research.google.com/drive/1_FpevpwjIugN0MFC4UtmO43_yyzApnPg (please run it locally because of an issue between Colaboratory and `tqdm_notebook`).