Maybe we can approximate it by…
As you point out, it will be easiest to just reuse your posterior samples and do a Monte Carlo estimate. You don’t strictly need Predictive for this, but you can use the Predictive.get_vectorized_trace method for convenience. Note that this will only work if all of your batch dims are correctly annotated with pyro.plate. If that’s not the case, you can also just iterate through the samples sequentially and use poutine.condition to condition your model on each posterior sample, giving one model trace per sample instead of the single vectorized trace below (a sketch of that sequential variant follows the snippet).
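For concreteness, here’s a minimal sketch of the kind of setup the snippet below assumes; the model, the observed site name "y", data, and mcmc are all placeholders, so substitute your own:

import math

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions.util import sum_rightmost
from pyro.infer import MCMC, NUTS, Predictive

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    # batch dim annotated with pyro.plate, as required for a vectorized trace
    with pyro.plate("data", data.shape[0]):
        pyro.sample("y", dist.Normal(loc, 1.), obs=data)

data = torch.randn(100) + 3.
mcmc = MCMC(NUTS(model), num_samples=500, warmup_steps=200)
mcmc.run(data)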
samples = mcmc.get_samples()
num_samples = mcmc.num_samples
pred = Predictive(model, samples)
tr = pred.get_vectorized_trace(data)
tr.compute_log_prob()
# sum out the log probs over the batch dims, keeping only the sample dim
logy = sum_rightmost(tr.nodes['y']['log_prob'], -1)
# MC estimate: log of the sample average of p(y | theta_s)
print(logy.logsumexp(0) - math.log(num_samples))
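If your model isn’t fully plated, the sequential fallback mentioned above might look like this; again a minimal sketch, reusing the placeholder model, samples, and num_samples from above:

import pyro.poutine as poutine

log_probs = []
for i in range(num_samples):
    # pick out the i-th posterior sample for every latent site
    sample = {name: value[i] for name, value in samples.items()}
    # condition the model on that sample and record one trace per sample
    tr_i = poutine.trace(poutine.condition(model, data=sample)).get_trace(data)
    tr_i.compute_log_prob()
    log_probs.append(tr_i.nodes['y']['log_prob'].sum())
logy = torch.stack(log_probs)
print(logy.logsumexp(0) - math.log(num_samples))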
Then, for each parameter sample, we input x_new and predict a y. Finally, we use the mean of all the y values as our prediction.
I don’t fully follow this, but you can use the predictions however you want. Maybe the snippet above clarifies that.
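That said, if the goal is the posterior predictive mean at new inputs, Predictive gives it directly. A hedged sketch: this assumes a model written to accept new inputs, e.g. model(x, y=None) with observed site "y" (unlike the toy model above), and x_new is a placeholder for your new inputs:

# assumes `def model(x, y=None): ...`; x_new is a placeholder tensor
pred = Predictive(model, samples)
y_samples = pred(x_new)['y']  # one predicted y per posterior sample
y_mean = y_samples.mean(0)    # posterior predictive mean at x_new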