Sampling from posterior predictive

It is hard to comment on the difference in timing without seeing the full code. A couple of observations though:

  • The first function assumes that SVI has already run and that the parameters in your param store are fitted. The second version fits the parameters first when you call svi.run, and only then samples from the posterior predictive. However, the second version currently does not pass num_steps (the number of SVI steps to run when fitting the variational parameters), so it is not clear to me that it is actually sampling from the posterior predictive at all. I would suggest plotting the results from the two functions to check that they agree. You might find this post useful: TracePredictive worse than sampling guides?
  • That said, given the implementation of TracePredictive, I would expect it to be slower than your first method, because it does the replaying twice: once to capture the posterior traces, and again to resample from these traces and replay the model forward. This second replay isn't needed for SVI but is required for MCMC and importance sampling. I think this should be addressed as part of "Make AbstractInfer classes lazily consume traces" (pyro-ppl/pyro issue #1725 on GitHub).
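To make the cost argument above concrete, here is a plain-Python sketch (no Pyro involved) that mimics the two strategies on a toy model and counts how many times the model is executed. Everything in it is an assumption for illustration: the model y ~ Normal(mu, 1), the "fitted" guide parameters GUIDE_LOC and GUIDE_SCALE, and the helper names are hypothetical. It is only an analogy for what TracePredictive does, not its actual implementation.

```python
import random
import statistics
from dataclasses import dataclass

# Hypothetical "fitted" variational parameters for the latent mean mu.
# In Pyro these would live in the param store after SVI has run.
GUIDE_LOC = 2.0
GUIDE_SCALE = 0.1


def guide(rng):
    """Draw the latent mu from the (assumed already fitted) Normal guide."""
    return rng.gauss(GUIDE_LOC, GUIDE_SCALE)


@dataclass
class ToyModel:
    """Hypothetical toy model y ~ Normal(mu, 1) that counts its executions."""
    calls: int = 0

    def __call__(self, rng, mu):
        self.calls += 1
        return rng.gauss(mu, 1.0)


def predictive_single_pass(model, num_samples, rng):
    """Sample mu from the guide and run the model forward once per sample."""
    return [model(rng, guide(rng)) for _ in range(num_samples)]


def predictive_two_pass(model, num_samples, rng):
    """Collect posterior traces first, then resample and replay the model."""
    # Pass 1: draw guide samples and replay the model against each one,
    # standing in for the replay that builds (and scores) each posterior trace.
    traces = []
    for _ in range(num_samples):
        mu = guide(rng)
        model(rng, mu)  # output discarded; only the trace (mu) is kept
        traces.append(mu)
    # Pass 2: resample the stored traces (uniform weights, as for SVI)
    # and replay the model forward again to get the predictive draws.
    resampled = rng.choices(traces, k=num_samples)
    return [model(rng, mu) for mu in resampled]


rng = random.Random(0)
one_pass_model, two_pass_model = ToyModel(), ToyModel()
ys_single = predictive_single_pass(one_pass_model, 1000, rng)
ys_double = predictive_two_pass(two_pass_model, 1000, rng)
print(one_pass_model.calls, two_pass_model.calls)  # 1000 2000
print(statistics.mean(ys_single), statistics.mean(ys_double))
```

For the same number of samples, the two-pass version runs the model twice as often, which is the kind of overhead I would expect to show up in your timings, while both versions should still produce matching predictive distributions.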