Any convenient way to get the ELBO for each observation?

I am following the DMM tutorial.

Besides computing the NLL and optimizing the network weights, I need to:

  • get the predictive samples for each observation (‘obs_*’ in the tutorial) given a minibatch of observations, i.e. sample from f(x) = \int q(z) p(x|z) dz
  • compute the log likelihood (or ELBO) for each observation given a minibatch of observations.

I managed to sample the ‘obs’ following https://forum.pyro.ai/t/how-to-use-pyro-condition-with-pyro-iarange-on-iid-data-points/349:

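import pyro
from pyro.infer import SVI

# dmm, adam, elbo and the mini_batch_* tensors are set up as in the DMM tutorial
svi = SVI(dmm.model, dmm.guide, adam, loss=elbo, num_samples=N_samp)
# the last argument (1.0) is the KL annealing factor
posterior = svi.run(mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths, 1.0)

# one observation site per time step
lst_site_name = ["obs_x_%d" % (t + 1) for t in range(T_max)]
emp = pyro.infer.EmpiricalMarginal(posterior, sites=lst_site_name)
sample_weight = emp.get_samples_and_weights()
ret = sample_weight[0]  # keep the samples, drop the weights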

However, I didn’t find a straightforward way to compute the marginal likelihood \int q(z)p(x|z)dz for each observation in a minibatch. Do I have to loop over the observations in the minibatch and call evaluate_loss on each one, as follows?

for i in range(len(minibatch)):
    svi.evaluate_loss(minibatch[i])

Thanks!

Pyro can’t currently compute per-datum ELBO on a batch of data, but we’ll probably add that ability in the next six months. @neerajprad has been thinking about this for training multiple independent models; see the GitHub issue “Supporting independent optimization problems via tensor DSL” (Issue #1330, pyro-ppl/pyro).
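
To make the target quantity concrete, here is a minimal sketch of a per-datum ELBO estimate in plain Python. It does not use Pyro and it is not the DMM. It assumes a toy model z ~ N(0, 1), x | z ~ N(z, 1) and a hand-written guide q(z | x) = N(x/2, sqrt(1/2)). The names normal_logpdf and per_datum_elbo are hypothetical helpers made up for this illustration. The point is only that each observation keeps its own Monte Carlo average of log p(x, z) - log q(z | x) instead of being summed into a single scalar loss.

```python
import math
import random


def normal_logpdf(value, mean, std):
    """Log density of a univariate Normal(mean, std) at value."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(std)
            - 0.5 * ((value - mean) / std) ** 2)


def per_datum_elbo(batch, num_particles=100, seed=0):
    """Return one Monte Carlo ELBO estimate per observation in batch.

    Toy assumptions (not the DMM): prior z ~ N(0, 1), likelihood
    x | z ~ N(z, 1), guide q(z | x) = N(x / 2, sqrt(1 / 2)).
    """
    rng = random.Random(seed)
    guide_std = math.sqrt(0.5)
    elbos = []
    for x in batch:
        guide_mean = 0.5 * x
        total = 0.0
        for _ in range(num_particles):
            # draw a latent sample from the guide for this observation
            z = rng.gauss(guide_mean, guide_std)
            log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
            log_q = normal_logpdf(z, guide_mean, guide_std)
            total += log_joint - log_q
        # average over particles, but keep observations separate
        elbos.append(total / num_particles)
    return elbos


batch = [-1.0, 0.3, 2.5]
elbos = per_datum_elbo(batch)
for x, value in zip(batch, elbos):
    print(x, value)
```

In this toy case the guide happens to be the exact posterior, so each per-datum ELBO coincides with the log marginal likelihood log p(x_i), where x_i ~ N(0, sqrt(2)). With an approximate guide such as the one learned for the DMM, each value would instead be a lower bound on log p(x_i).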

Any news on this?