Problems with predictive for MCMC+NN

Check out part 2 of the tutorial, which uses MCMC on the same dataset. This is the same issue as earlier: getting your batch dimensions to align with pyro.plate when using pyro.module. pyro.module calls pyro.sample internally, so if you sample anything other than PyTorch scalars, you need to account for the extra batch dimensions with pyro.plate. This restriction might seem cumbersome, but Predictive needs it to make vectorized predictions correctly. In many cases, predictions without vectorization will probably be fast enough, and I will add an option to Predictive to do just that.

You can also write your own sequential prediction function (untested), which should work with pyro.module:

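import torch
import pyro.poutine as poutine

def predict_mcmc(model, model_samples, *args, **kwargs):
    # model_samples is a list of dicts, one per posterior draw: {site name: value}
    preds = []
    for sample in model_samples:
        # Condition the model on a single draw and record its execution trace
        model_trace = poutine.trace(poutine.condition(model, sample)).get_trace(*args, **kwargs)
        preds.append(model_trace.nodes['obs']['value'])
    return torch.stack(preds)

# Split the batched MCMC samples into one dict per draw
posterior = mcmc.get_samples()
num_samples = len(next(iter(posterior.values())))
samples = [{k: v[i] for k, v in posterior.items()} for i in range(num_samples)]
preds = predict_mcmc(model, samples, x_data, None)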