Extract the KL divergence term from the loss. How?

If you want analytic KL divergences, you need to use TraceMeanField_ELBO. The elegant way to do this sort of thing is to use pyro.poutine.block:

import torch
import pyro
from pyro.distributions import Normal
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam
from pyro.poutine import block


# Model: latent z with a standard normal prior, plus one observed site.
def model():
    pyro.sample("z", Normal(0.0, 1.0))
    pyro.sample("obs", Normal(0.0, 2.0), obs=torch.tensor([0.3]))

# Guide: variational distribution over z (here identical to the prior).
def guide():
    pyro.sample("z", Normal(0.0, 1.0))

# Hide the observed site so the loss contains only the KL term.
model_no_obs = block(model, hide=["obs"])
optim = Adam({"lr": 0.001})
svi = SVI(model_no_obs, guide, optim, TraceMeanField_ELBO())

print("KL( Normal(0,1) | Normal(0,1) ) = ", svi.evaluate_loss())

This prints the deterministic result 0.0, as expected, since the guide and the prior are identical.
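
As a quick sanity check with a non-zero answer, you could swap in a guide whose parameters differ from the prior and compare the loss against the closed-form Gaussian KL. This is just a sketch; the guide name and its particular mean and scale below are made up for illustration:

import math
import torch
import pyro
from pyro.distributions import Normal
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam
from pyro.poutine import block


def model():
    pyro.sample("z", Normal(0.0, 1.0))
    pyro.sample("obs", Normal(0.0, 2.0), obs=torch.tensor([0.3]))

# Hypothetical guide with a shifted, wider distribution over z.
def shifted_guide():
    pyro.sample("z", Normal(0.5, 2.0))

svi = SVI(block(model, hide=["obs"]), shifted_guide, Adam({"lr": 0.001}),
          TraceMeanField_ELBO())

# Closed-form KL( N(mu, sigma^2) || N(0, 1) ) = -log(sigma) + (sigma^2 + mu^2)/2 - 1/2
mu, sigma = 0.5, 2.0
analytic_kl = -math.log(sigma) + (sigma ** 2 + mu ** 2) / 2.0 - 0.5

print(svi.evaluate_loss())  # ~0.9319, computed analytically by TraceMeanField_ELBO
print(analytic_kl)          # 0.9318..., should match

Because the KL is computed analytically rather than by Monte Carlo, repeated calls to evaluate_loss() give the same number.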
