Check out part 2 of the tutorial, which uses MCMC on the same dataset. This is the same issue as before: your batch dimensions need to align with `pyro.plate` when you use `pyro.random_module`. `pyro.random_module` calls `pyro.sample` internally, and if you sample anything other than PyTorch scalars, you need to account for the batch dims by using `pyro.plate`. This restriction might seem cumbersome, but it is needed to do vectorized predictions correctly. In many cases you can probably get fast enough predictions without this vectorization, and I will add an option to `predictive` to do just that.

You can also write your own sequential predictive function (not tested), which should work with `pyro.random_module`:
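```python
import torch
import pyro.poutine as poutine

def predict_mcmc(model, model_samples, *args, **kwargs):
    preds = []
    for sample in model_samples:
        # Fix latent sites to one posterior draw and record the sampled 'obs' site
        model_trace = poutine.trace(poutine.condition(model, sample)).get_trace(*args, **kwargs)
        preds.append(model_trace.nodes['obs']['value'])
    return torch.stack(preds)

# Split the dict of stacked MCMC samples into one dict per draw
posterior = mcmc.get_samples()
num_samples = len(next(iter(posterior.values())))
samples = [{k: v[i] for k, v in posterior.items()} for i in range(num_samples)]
preds = predict_mcmc(model, samples, x_data, None)
```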