Extract KL loss in VAE-type models from SVI?

I’m looking to extract the exact value of each part of the loss for my model during optimization. For example, in a typical variational autoencoder you can write the loss (the negative ELBO) as:

reconstruction_loss + KL_loss
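For reference, this is the standard decomposition of the negative ELBO (the quantity SVI minimizes) for a datapoint x, latent z, guide q(z | x) and prior p(z). The KL term is itself an expectation under the guide, which is why it can be estimated from samples of z:

```latex
-\mathrm{ELBO}
  = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction\_loss}}
  + \underbrace{\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)}_{\text{KL\_loss}},
\qquad
\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)
  = \mathbb{E}_{q(z \mid x)}\big[\log q(z \mid x) - \log p(z)\big].
```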

So what I want to know is: what’s the best way to get each of these individual loss terms from Pyro?

A suggestion I’d heard previously is to run the traces manually and then read off each sample site, as follows:

from pyro import poutine

# run the guide to sample z, then replay the model against that same sample
trace = poutine.trace(self.guide).get_trace(src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)
trace = poutine.trace(poutine.replay(self.model, trace)).get_trace(src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)
trace.compute_log_prob()
loss = -1.0 * trace.nodes['preds']['log_prob_sum']  # presumed reconstruction term
kl_term = trace.nodes['z']['log_prob_sum']  # presumed KL term

where ‘z’ is my latent variable and ‘preds’ is the observed site holding the things I am predicting.

I ask because my other idea for isolating the KL value was to define a second model and guide that include only the parts up to the sample site ‘z’, and then call evaluate_loss on those. When I do this, though, I get completely different values, and I’m not sure which ones to use.

Any thoughts appreciated!

I think I may have answered this myself (but if anybody reading this can verify it, that’d be great!)

So the part my little code snippet is missing (I think, anyway) is that my supposed “KL” term is actually just the log probability of the guide’s sample of z, evaluated under the model’s distribution for z (the prior). To properly calculate the KL you need the traces of both the model and the guide, and the actual KL is then something like:

from pyro import poutine

# sample z from the guide, then replay the model on that same z
guide_trace = poutine.trace(self.guide).get_trace(src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)
model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(src, trg, src_mask, trg_mask, src_lengths, trg_lengths, trg_y, kl_anneal)
guide_trace.compute_log_prob()
model_trace.compute_log_prob()
loss = -1.0 * model_trace.nodes['preds']['log_prob_sum']  # presumed reconstruction term
kl_term = guide_trace.nodes['z']['log_prob_sum'] - model_trace.nodes['z']['log_prob_sum']  # presumed KL term: log q(z | x) - log p(z)

which gives a single-sample Monte Carlo estimate of the KL term in VAEs (I assume it’s only an approximation because it’s computed from one sample of z rather than from the analytic Gaussian form).