Sampling from posterior predictive

Hello,

First off, amazing job on Pyro! Major kudos :slight_smile:
How do I sample from the posterior predictive for an SVI-trained model efficiently? At the moment, I sample a guide trace for each desired posterior predictive sample, replay the model with the guide trace, and sample once from it, like this:

    ppc = []
    dummy_obs = torch.zeros((1, self.D))
    for sample in range(n_samples):
        # sample latent variables from the fitted guide
        guide_trace = pyro.poutine.trace(self.guide).get_trace(dummy_obs)
        # replay the model against those latents and record the sampled obs
        posterior_predictive = pyro.poutine.trace(pyro.poutine.replay(self.model, guide_trace)).get_trace(dummy_obs)
        ppc.append(posterior_predictive.nodes['obs']['value'].detach().numpy())
    ppc = np.squeeze(np.array(ppc))

Is there a better way?

edit:
After a bit more trial and error, I arrived at this:

    def posterior_predictive(self, n_samples=None):
        if n_samples is None:
            n_samples = self.N
        dummy_obs = torch.zeros((1, D))
        with pyro.plate('n_samples', n_samples, dim=-2):
            # sample latent variables from guide
            guide_trace = pyro.poutine.trace(self.guide).get_trace(dummy_obs)
            # sample observations given latent variables
            blockreplay = pyro.poutine.block(fn=pyro.poutine.replay(self.model, guide_trace), expose=['obs'])
            posterior_predictive = pyro.sample('pred_obs', blockreplay, dummy_obs)
        return posterior_predictive

This should work fine. Just one thing to note: if your model uses the obs keyword to designate observed sites, make sure you are not simply returning the observed data itself because of that keyword. One way to avoid this is to have your model take the observed data via a separate default kwarg, as in the example below; when you want to sample from the posterior predictive, you can omit y (obs=None leaves the sample site unconstrained, so you sample from the predictive distribution instead). Another way is to use poutine.condition instead of passing obs= in the model's sample statements.

def model(x, y=None):
    ..
    return pyro.sample("y", dist, obs=y)

I think using poutines directly, as in your example code, and becoming comfortable with them is a great idea. You can also use the TracePredictive class to do this for you, e.g.

svi = SVI(model, guide, num_samples=n_samples, ..).run(x, y)
trace_pred = TracePredictive(model, svi, num_samples=n_samples).run(x) 

You can sample from trace_pred directly by calling trace_pred(), or access all the underlying traces via trace_pred.exec_traces.
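
For example, assuming the observed site is named "y" as in the model snippet above, you could collect the predictive draws from the stored traces roughly like this (a sketch, not tested against your model):

import numpy as np

# stack the value at the observed site from each executed trace
samples = np.stack([tr.nodes["y"]["value"].detach().numpy()
                    for tr in trace_pred.exec_traces])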

Thanks for the answer! I’m not sure if I’m using it wrong, but TracePredictive seems to be slower…

    def posterior_predictive(self, n_samples=None):
        if n_samples is None:
            n_samples = self.N
        dummy_obs = torch.zeros((1, D))
        # run the guide once to find out how many plate dimensions it uses
        guide_trace = pyro.poutine.trace(self.guide).get_trace(dummy_obs)
        plate_stack_depth = max(len(node['cond_indep_stack'])
                                for name, node in guide_trace.nodes.items()
                                if 'cond_indep_stack' in node)
        with pyro.plate('n_samples', n_samples, dim=-(plate_stack_depth + 1)):
            # sample latent variables from guide
            guide_trace = pyro.poutine.trace(self.guide).get_trace(dummy_obs)
            # sample observations given latent variables
            blockreplay = pyro.poutine.block(fn=pyro.poutine.replay(self.model, guide_trace), expose=['obs'])
            posterior_predictive = pyro.sample('pred_obs', blockreplay, dummy_obs)
        return posterior_predictive

    def posterior_predictive2(self, n_samples=None):
        if n_samples is None:
            n_samples = self.N
        svi2 = SVI(self.model, self.guide, num_samples=n_samples, optim=self.optim, loss=self.elbo).run(torch.zeros((1, D)))
        return pyro.infer.TracePredictive(self.model, svi2, num_samples=n_samples).run(torch.zeros((1, D)))
%time ig.posterior_predictive(n_samples=1000)
CPU times: user 0 ns, sys: 4 ms, total: 4 ms
Wall time: 4.69 ms

%time ig.posterior_predictive2(n_samples=1000)
CPU times: user 2.14 s, sys: 20 ms, total: 2.16 s
Wall time: 2.16 s

Initializing a new SVI object may be the bottleneck here, but I guess there’s no way around that if I want more samples than the number used in the original SVI?

It is hard to comment on the difference in timing without seeing the full code. A couple of observations though:

  • The first function already assumes SVI has run and that the parameters in your param store are fitted. The second version would fit the parameters first when you call svi.run, and then sample from the posterior predictive. However, the second version is currently missing num_steps (i.e. the number of SVI steps to run to fit the variational parameters), so it is not clear to me whether it is actually sampling from the posterior predictive at all. I would suggest plotting or comparing the results from the two functions to ensure that you get identical results (see the sketch after this list). You might find this post useful - TracePredictive worse than sampling guides?.
  • That said, given the implementation of TracePredictive, I would expect it to be slower than your first method, because it does the replaying twice: once to capture the posterior traces, and then again to resample from those and replay the model forward. That second pass isn't needed for SVI, but it is required for MCMC and importance sampling. I think this should be addressed as part of Make AbstractInfer classes lazily consume traces · Issue #1725 · pyro-ppl/pyro · GitHub.
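
Along the lines of the first point, a rough sanity check might look like the sketch below. It assumes ig is the object from your timing snippet, that the observed site is named 'obs', and that posterior_predictive returns a tensor of draws; if both functions really sample from the same posterior predictive, the summary statistics should roughly agree.

import numpy as np

# draws from the hand-rolled version
samples1 = np.squeeze(ig.posterior_predictive(n_samples=1000).detach().numpy())

# draws from the TracePredictive version, pulled out of the stored traces
trace_pred = ig.posterior_predictive2(n_samples=1000)
samples2 = np.squeeze(np.stack([tr.nodes['obs']['value'].detach().numpy()
                                for tr in trace_pred.exec_traces]))

# compare means and standard deviations of the two sets of samples
print(samples1.mean(axis=0), samples2.mean(axis=0))
print(samples1.std(axis=0), samples2.std(axis=0))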