It is hard to comment on the difference in timing without seeing the full code. A couple of observations, though:
- The first function already assumes `SVI` has run and the parameters in your param store are fitted. The second version would end up fitting the parameters first when you call `svi.run`, and then sampling from the posterior predictive. However, the second version is currently missing `num_steps` (i.e. run `SVI` for the number of steps given by `num_steps` to fit the variational parameters), so it is not clear to me whether it is actually sampling from the posterior predictive at all. I would suggest plotting the results from the two functions to ensure that you get identical results. You might find this post useful - TracePredictive worse than sampling guides?.
- That said, given the implementation of `TracePredictive`, I would expect it to be slower than your first method, because it does the replaying twice: once to capture the posterior traces, and again to resample from these and replay the model forward, which isn't needed for SVI but is required for MCMC and importance sampling. I think this should be addressed as part of Make AbstractInfer classes lazily consume traces · Issue #1725 · pyro-ppl/pyro · GitHub.