How can I extract the KL divergence term from the loss?

I want to know whether I can get the KL divergence term out of the loss.

The following code outputs the normalized loss per epoch of training:

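for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0  # reset the running loss at the start of each epoch
    for x, label, _ in train_loader:
        # svi.step takes one optimizer step and returns this batch's loss (negative ELBO)
        epoch_loss += svi.step(x, label)

    # return normalized epoch loss (average loss per batch)
    normalizer_train = len(train_loader)
    total_epoch_loss_train = epoch_loss / normalizer_train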

The loss is the negative ELBO, meaning: loss = -reconstruction term + KL(posterior || prior)
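Written out with the expectations made explicit, where q(z | x) is the guide (approximate posterior) and p(z) the prior, the decomposition is the standard identity below. It shows that only the first term involves the observations, so removing or strongly down-weighting the observed sites in the model leaves (an estimate of) the KL term.

```latex
-\mathrm{ELBO}(x) = -\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] + \mathrm{KL}\big(q(z \mid x) \,\|\, p(z)\big),
\qquad
\mathrm{KL}\big(q(z \mid x) \,\|\, p(z)\big) = \mathbb{E}_{q(z \mid x)}\big[\log q(z \mid x) - \log p(z)\big]
```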

Is there any way to also get the KL term separately when using

svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())




One thing you can do is use poutine.scale to scale all the observation sites in the model by some very small number (e.g. 1.0e-9), similar to what is done here; svi.evaluate_loss() will then give you (a stochastic estimate of) the KL.

I had the same question (Extract KL loss in VAE type models from SVI?), and I think the solution I proposed in that thread is… correct? I'm using a Gaussian with diagonal covariance as my prior, so I also tried calculating the KL analytically, which gives results similar to my semi-hacky solution. One minor discrepancy when using the trace is that Pyro doesn't calculate the KL analytically; it approximates it with samples from the reparameterized distribution (at least in the Gaussian case), so the value varies slightly with each call. Again, I'm not 100% sure, but this seems to be the case.
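To illustrate the point about sampling, here is a pure-Python sketch (no Pyro involved) comparing the closed-form KL between a diagonal Gaussian q = N(mu, diag(sigma^2)) and a standard normal prior with a reparameterized Monte Carlo estimate of E_q[log q(z) - log p(z)]. The function names and the values of mu and sigma are hypothetical. Single-sample estimates differ between random seeds, while the closed form is fixed; averaging many samples approaches the closed form.

```python
import math
import random


def kl_diag_gauss_std_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dimensions."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mu, sigma))


def log_normal_pdf(x, mean, std):
    """Log density of a univariate normal distribution."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def kl_monte_carlo(mu, sigma, num_samples, rng):
    """Estimate E_q[log q(z) - log p(z)] using reparameterized samples z = mu + sigma * eps."""
    total = 0.0
    for _ in range(num_samples):
        for m, s in zip(mu, sigma):
            z = m + s * rng.gauss(0.0, 1.0)
            total += log_normal_pdf(z, m, s) - log_normal_pdf(z, 0.0, 1.0)
    return total / num_samples


mu = [1.0, -0.5]
sigma = [0.5, 2.0]
analytic = kl_diag_gauss_std_normal(mu, sigma)
mc_a = kl_monte_carlo(mu, sigma, 1, random.Random(0))  # one sample, seed 0
mc_b = kl_monte_carlo(mu, sigma, 1, random.Random(1))  # one sample, seed 1
mc_many = kl_monte_carlo(mu, sigma, 100_000, random.Random(2))
print("analytic:", analytic)
print("single-sample estimates:", mc_a, mc_b)
print("100k-sample estimate:", mc_many)
```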

If you want analytic KL divergences, you need to use TraceMeanField_ELBO. The elegant way to do this sort of thing is to use pyro.poutine.block:

import torch
import pyro
from pyro.distributions import Normal
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam
from pyro.poutine import block

def model():
    # prior p(z) and one observed data point
    pyro.sample("z", Normal(0.0, 1.0))
    pyro.sample("obs", Normal(0.0, 2.0), obs=torch.tensor([0.3]))

def guide():
    # variational distribution q(z); here identical to the prior
    pyro.sample("z", Normal(0.0, 1.0))

# hide the "obs" site so the loss contains only the KL(q(z) || p(z)) term
model_no_obs = block(model, hide=["obs"])
optim = Adam({"lr": 0.001})
svi = SVI(model_no_obs, guide, optim, TraceMeanField_ELBO())

print("KL( Normal(0,1) | Normal(0,1) ) = ", svi.evaluate_loss())

This gives the deterministic result 0.0, as expected.
