I am following the DMM (Deep Markov Model) tutorial.
Besides computing the NLL and optimizing the network weights, I need to
- get the predictive samples for each observation (‘obs_*’ in the tutorial) given a minibatch of observations, i.e. sample from f(x) = \int q(z) p(x|z) dz
- compute the log likelihood (or ELBO) for each observation given a minibatch of observations (the estimates I have in mind are written out below).
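To spell out the two quantities, these are the Monte Carlo estimates I have in mind (my own notation; z^{(s)} are samples from the guide q(z)):

f(x) = \int q(z) p(x|z) dz \approx (1/S) \sum_{s=1}^{S} p(x | z^{(s)}),  with z^{(s)} \sim q(z)
\log p(x_i) \ge E_{q(z_i)}[ \log p(x_i | z_i) + \log p(z_i) - \log q(z_i) ]  (per-observation ELBO)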
I managed to sample the ‘obs_*’ sites following https://forum.pyro.ai/t/how-to-use-pyro-condition-with-pyro-iarange-on-iid-data-points/349:
import pyro
from pyro.infer import SVI, EmpiricalMarginal

svi = SVI(dmm.model, dmm.guide, adam, loss=elbo, num_samples=N_samp)
posterior = svi.run(mini_batch, mini_batch_reversed, mini_batch_mask,
                    mini_batch_seq_lengths, 1.0)
# collect the emission sites obs_x_1 ... obs_x_T_max
lst_site_name = ["obs_x_%d" % (t + 1) for t in range(T_max)]
emp = EmpiricalMarginal(posterior, sites=lst_site_name)
samples, weights = emp.get_samples_and_weights()
ret = samples
However, I didn’t find a straightforward way to compute the marginalized likelihood \int q(z) p(x|z) dz for each observation in a minibatch. Do I have to loop over the observations in the minibatch and call evaluate_loss on each one, as follows?
for i in range(len(mini_batch)):
    svi.evaluate_loss(mini_batch[i:i+1])  # plus per-observation slices of the other model arguments
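For reference, this is a rough sketch of the kind of per-observation estimate I am trying to get without the Python loop, by tracing the guide and replaying the model. It assumes every sample site’s log_prob keeps the minibatch as its leftmost dimension (per_obs_elbo and num_particles are just names I made up here):

import torch
import pyro.poutine as poutine

def per_obs_elbo(dmm, model_args, num_particles=10):
    """Rough per-observation ELBO estimate via a traced guide and replayed model.
    Assumes all sample sites are batched along the leftmost (minibatch) dimension."""
    batch_size = model_args[0].size(0)
    elbo = torch.zeros(batch_size)
    for _ in range(num_particles):
        guide_trace = poutine.trace(dmm.guide).get_trace(*model_args)
        model_trace = poutine.trace(
            poutine.replay(dmm.model, trace=guide_trace)).get_trace(*model_args)
        model_trace.compute_log_prob()
        guide_trace.compute_log_prob()
        # log p(x, z), summed over everything except the batch dimension
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                elbo = elbo + site["log_prob"].reshape(batch_size, -1).sum(-1)
        # minus log q(z), likewise per observation
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                elbo = elbo - site["log_prob"].reshape(batch_size, -1).sum(-1)
    return elbo / num_particles

If something like this is roughly right, is there a more straightforward built-in way to get per-observation losses?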
Thanks!