I’m looking to extract the exact loss values for the different parts of my model during optimization. For example, in a typical Variational Autoencoder you can rewrite the (negative) ELBO as something like:

reconstruction_loss + KL_loss

So what I want to know is: what’s the best way to get each of these individual loss terms out of Pyro?
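For context on the split I mean, here is a minimal sketch in plain `torch.distributions` (the Normal guide/prior/likelihood and all the numbers are just illustrative assumptions, not my actual model): with Gaussian guide and prior the KL term even has a closed form, while the reconstruction term is a single-sample estimate of the expected negative log-likelihood.

```python
import torch
from torch.distributions import Normal, kl_divergence

x = torch.tensor(0.5)                              # one observed value (illustrative)
q = Normal(torch.tensor(0.1), torch.tensor(0.9))   # guide q(z|x)
p = Normal(torch.tensor(0.0), torch.tensor(1.0))   # prior p(z)

z = q.rsample()                                    # single-sample (reparameterized) draw
reconstruction_loss = -Normal(z, 1.0).log_prob(x)  # -log p(x|z)
KL_loss = kl_divergence(q, p)                      # analytic KL(q || p)

negative_elbo = reconstruction_loss + KL_loss
```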

A suggestion I’d heard previously is to run the traces manually and then read off each sample site, like so:

```python
guide_trace = poutine.trace(self.guide).get_trace(
    src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)

model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(
    src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)

model_trace.compute_log_prob()

loss = -1.0 * model_trace.nodes["preds"]["log_prob_sum"]  # presumed reconstruction term
kl_term = model_trace.nodes["z"]["log_prob_sum"]          # presumed KL term
```

where `z` is my latent variable and `preds` are the values I am predicting.

I ask because my other idea for isolating the KL value was to define a second model/guide pair that includes only the parts up to the sample site `z`, and then call `evaluate_loss` on that. But when I do this I end up with completely different values, and I’m not sure which one to trust.

Any thoughts appreciated!