Hi,
So i want to know if i can get the Kl divergence term from the loss.
The following code outputs the normalized loss per epoch of training:
for epoch in range(NUM_EPOCHS):
epoch_loss=0
for x,label,_ in train_loader:
epoch_loss += svi.step(x,label)
# return normalized epoch loss
normalizer_train = len(train_loader)
total_epoch_loss_train = epoch_loss / normalizer_train
The loss is the negative ELB0, meaning : loss= -reconstuction term + KL(posterior|prior)
Is it any way to get also separately the KL ,when using
svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())
?
thanks!
1 Like
one thing you can do is poutine.scale
by some very small number (e.g. 1.0e-9
) all the observations in the model, similar to what is done here and then svi.evaluate_loss()
will give you (a stochastic estimate of) the KL
I had the same question, Extract KL loss in VAE type models from SVI?, I think what I propose as the solution in that thread is…correct? I’m using a Gaussian w/ diagonal covariance as my prior so I also tried calculating it analytically as well which gives similar results as my semi-hacky solution. One minor discrepancy from using the trace is that Pyro doesn’t calculate it analytically, it approximates it with samples from the re-parameterized distribution (at least in the Gaussian case…) so doing it that way will vary slightly with each call to the trace. Again, not 100% but this seems to be the case.
if you want analytic KL divergences you need to use TraceMeanField_ELBO
. the elegant way to do this sort of thing is to use pyro.poutine.block
:
import torch
import pyro
from pyro.distributions import Normal
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam
from pyro.poutine import block
def model():
pyro.sample("z", Normal(0.0, 1.0))
pyro.sample("obs", Normal(0.0, 2.0), obs=torch.tensor([0.3]))
def guide():
pyro.sample("z", Normal(0.0, 1.0))
model_no_obs = block(model, hide=["obs"])
optim = Adam({"lr": 0.001})
svi = SVI(model_no_obs, guide, optim, TraceMeanField_ELBO())
print("KL( Normal(0,1) | Normal(0,1) ) = ", svi.evaluate_loss())
which will give the deterministic result 0.0 as expected
2 Likes