I am trying to get auto-enumeration working with a `TransformedDistribution`, but inference seems to break because the parameters can't be expanded correctly.
I have tried to make an MFE for a mixture model below:
The issue is that `.expand` for `TransformedDistribution` in Pyro is a bit restrictive: it will not let us expand on the right, e.g. expand a batch shape of `[3, 1]` to `[3, 3000]`. This should, however, work fine with the PyTorch nightly release, because there the `.expand` call is dispatched to `torch.distributions.TransformedDistribution.expand`, which handles such cases. Follow the instructions here if you would like to experiment with the Pyro branch tracking the upcoming release.
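A minimal sketch of the expansion in question, using plain PyTorch and assuming a PyTorch version (1.0 or later) in which `TransformedDistribution.expand` is available. `LogNormal` is used only as a convenient built-in `TransformedDistribution` (a `Normal` pushed through an exp transform); the shapes mirror the `[3, 1]` to `[3, 3000]` case above.

```python
import torch
import torch.distributions as dist

# LogNormal is implemented in PyTorch as a TransformedDistribution
# (a Normal base distribution followed by an ExpTransform).
d = dist.LogNormal(torch.zeros(3, 1), torch.ones(3, 1))
print(isinstance(d, dist.TransformedDistribution))  # True
print(d.batch_shape)  # torch.Size([3, 1])

# Expand the trailing singleton batch dim ("on the right") to 3000.
expanded = d.expand(torch.Size([3, 3000]))
print(expanded.batch_shape)  # torch.Size([3, 3000])

# The expanded distribution can now score a [3, 3000] batch of positive values.
x = torch.rand(3, 3000) + 0.1
lp = expanded.log_prob(x)
print(lp.shape)  # torch.Size([3, 3000])
```

Under the hood, PyTorch maps the requested batch shape back through the transforms and expands the base `Normal` from `[3, 1]` to `[3, 3000]`, which is ordinary broadcasting of a size-1 dimension.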
Why do we need such expands? While we typically don't encounter this, it becomes a problem here because the preceding call to `Categorical.sample` calls `.enumerate_sample` under the hood, which collapses the batch dim into a singleton dim. So we will face this issue whenever a model has both discrete sites and a `TransformedDistribution`. Note that discrete site enumeration has not been publicly released yet, so this should be fixed by the next Pyro release, which will use PyTorch 1.0!
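To see where the singleton dim comes from, here is a plain-PyTorch sketch of the shapes involved. It is not Pyro's internal enumeration code; it only assumes that enumerating a discrete site behaves like `Categorical.enumerate_support(expand=False)`, which puts the components in a new leftmost dim and leaves a size-1 dim where the data plate was. The per-component `locs` are hypothetical values chosen for illustration.

```python
import torch
import torch.distributions as dist

num_data, num_components = 3000, 3

# Mixture assignments for a 3000-point data plate.
assignment_dist = dist.Categorical(logits=torch.zeros(num_data, num_components))
print(assignment_dist.batch_shape)  # torch.Size([3000])

# Enumerating without expanding: components go to a new leftmost dim,
# and the data dim collapses to a singleton.
support = assignment_dist.enumerate_support(expand=False)
print(support.shape)  # torch.Size([3, 1])

# Hypothetical per-component locations, indexed by the enumerated values.
locs = torch.tensor([-1.0, 0.0, 1.0])
component = dist.LogNormal(locs[support], 1.0)
print(component.batch_shape)  # torch.Size([3, 1])

# Scoring the 3000-element plate needs the [3, 1] -> [3, 3000] expand.
expanded = component.expand(torch.Size([num_components, num_data]))
print(expanded.batch_shape)  # torch.Size([3, 3000])
```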