Hi all,
I'm training a Bayesian neural network for classification (GSC2 speech data) using Pyro. During training I noticed that my ELBO terms, especially the KL divergence, are unstable or even negative.
To debug this, I implemented two decomposition methods:
- SVI.evaluate_loss at different anneal factors
- Manual trace + replay decomposition (log_joint, log_guide, etc.)
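For reference, the decomposition I expect both methods to recover is the usual one,

ELBO = E_q[ log p(y | x, w) ] - KL( q(w) || p(w) ),

i.e. a reconstruction term minus a KL term, with the anneal factor multiplying only the KL term.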
Here is a minimal reproducible example where both methods show very different decomposition results. I want to confirm:
a) Are these methods equivalent in principle?
b) Which one is more reliable for debugging?
c) Is my KL calculation correct?
import torch
import pyro
import pyro.optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro import poutine

torch.manual_seed(42)
pyro.set_rng_seed(42)

# toy data: 2 examples, 2 features, 3 classes
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
y = torch.tensor([0, 2])


def model(x, y, anneal_factor=1.0):
    num_features = x.shape[1]
    num_classes = 3
    w = pyro.sample("w", dist.Normal(
        torch.zeros(num_classes, num_features),
        torch.ones(num_classes, num_features)
    ).to_event(2), infer={"scale": anneal_factor})
    logits = torch.matmul(x, w.t())
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Categorical(logits=logits), obs=y)


guide = ScaledAutoNormal(model)  # defined further down in this post
elbo = Trace_ELBO(num_particles=10)
svi = SVI(model, guide, optim=pyro.optim.Adam({"lr": 1e-2}), loss=elbo)
def compute_elbo_decomposition_method1(svi, x, y, anneal_factor=1.0):
    # evaluate_loss returns the loss, i.e. the negative ELBO estimate
    elbo_reconstruction = svi.evaluate_loss(x, y, anneal_factor=1e-10)
    elbo_full = svi.evaluate_loss(x, y, anneal_factor=1.0)
    elbo_current = svi.evaluate_loss(x, y, anneal_factor=anneal_factor)
    kl_div = elbo_full - elbo_reconstruction
    return {
        "method": "method1",
        "reconstruction_loss": elbo_reconstruction,
        "kl_divergence": kl_div,
        "elbo_full": elbo_full,
        "elbo_current": elbo_current,
    }
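# Note: each evaluate_loss call above draws its own guide samples, so the
# difference elbo_full - elbo_reconstruction mixes Monte Carlo noise from two
# independent 10-particle estimates.
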
def compute_elbo_decomposition_method2(model, guide, x, y, anneal_factor=1.0, num_particles=10):
    recon_losses = []
    kl_divs = []
    elbos = []
    log_joints = []
    log_guides = []
    log_priors = []
    for _ in range(num_particles):
        # sample from the guide, then replay the model against the guide's sample
        guide_trace = poutine.trace(guide).get_trace(x, y)
        model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace(x, y, anneal_factor=anneal_factor)
        log_joint = model_trace.log_prob_sum().item()
        log_guide = guide_trace.log_prob_sum().item()
        elbo = log_joint - log_guide
        reconstruction_loss = model_trace.nodes["obs"]["log_prob_sum"].item()
        # log-prior = sum of log-probs of the latent (non-observed) sample sites
        log_prior = sum(
            site["log_prob_sum"].item()
            for name, site in model_trace.nodes.items()
            if site["type"] == "sample" and not site["is_observed"]
        )
        # single-sample Monte Carlo estimate of KL(q || p)
        kl_divergence = log_guide - log_prior
        recon_losses.append(reconstruction_loss)
        kl_divs.append(kl_divergence)
        elbos.append(elbo)
        log_joints.append(log_joint)
        log_guides.append(log_guide)
        log_priors.append(log_prior)
    return {
        "method": "method2",
        "reconstruction_loss": sum(recon_losses) / num_particles,
        "kl_divergence": sum(kl_divs) / num_particles,
        "elbo": sum(elbos) / num_particles,
        "elbo_annealed": sum(recon_losses) / num_particles - anneal_factor * sum(kl_divs) / num_particles,
        "log_joint": sum(log_joints) / num_particles,
        "log_guide": sum(log_guides) / num_particles,
        "log_prior": sum(log_priors) / num_particles,
    }
anneal_factor = 0.5
out1 = compute_elbo_decomposition_method1(svi, x, y, anneal_factor=anneal_factor)
out2 = compute_elbo_decomposition_method2(model, guide, x, y, anneal_factor=anneal_factor)
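For question (c), a closed-form reference value can be computed directly from the guide parameters. A minimal sketch, assuming the guide has been run once so its parameters exist and that guide.locs.w / guide.scales.w expose the constrained values (the site "w" has real support, so the AutoNormal posterior over it is just a diagonal Normal):

from torch.distributions import kl_divergence

guide(x, y)  # run the guide once so its parameters are initialized

q_w = dist.Normal(guide.locs.w, guide.scales.w).to_event(2)          # variational posterior over w
p_w = dist.Normal(torch.zeros(3, 2), torch.ones(3, 2)).to_event(2)   # prior over w from the model
print("closed-form KL(q(w) || p(w)):", kl_divergence(q_w, p_w).item())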
The ScaledAutoNormal guide is defined as follows:
from contextlib import ExitStack
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
from pyro.distributions import constraints
from pyro.distributions.util import sum_rightmost
from pyro.infer.autoguide.initialization import InitMessenger, init_to_feasible
from pyro.infer.autoguide.utils import deep_setattr, helpful_support_errors
from pyro.nn.module import PyroModule, PyroParam
from pyro.ops.tensor_utils import periodic_repeat
from torch.distributions import biject_to
class ScaledAutoNormal(pyro.infer.autoguide.AutoNormal):
    """
    AutoNormal subclass that adds KL annealing to an autoguide without having
    to define the guide explicitly:
    - forward() applies the anneal factor via pyro.poutine.scale(), as in
      https://pyro.ai/examples/dmm.html
    - the location parameters can optionally be frozen (freeze_locs=True)
    """

    def __init__(
        self, model, *, init_loc_fn=init_to_feasible, init_scale=0.1, create_plates=None, freeze_locs=False
    ):
        self.init_loc_fn = init_loc_fn
        self.freeze_locs = freeze_locs
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        model = InitMessenger(self.init_loc_fn)(model)
        super().__init__(model, create_plates=create_plates)

    def site_forward(self, name, site, plates, result, *args, **kwargs):
        transform = biject_to(site["fn"].support)
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    stack.enter_context(plates[frame.name])
            site_loc, site_scale = self._get_loc_and_scale(name)
            unconstrained_latent = pyro.sample(
                name + "_unconstrained",
                dist.Normal(
                    site_loc,
                    site_scale,
                ).to_event(self._event_dims[name]),
                infer={"is_auxiliary": True},
            )
            value = transform(unconstrained_latent)
            if poutine.get_mask() is False:
                log_density = 0.0
            else:
                log_density = transform.inv.log_abs_det_jacobian(
                    value,
                    unconstrained_latent,
                )
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() + site["fn"].event_dim,
                )
            delta_dist = dist.Delta(
                value,
                log_density=log_density,
                event_dim=site["fn"].event_dim,
            )
            result[name] = pyro.sample(name, delta_dist)

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # scale every guide site by the anneal factor passed in via kwargs
            with pyro.poutine.scale(scale=kwargs.get("anneal_factor", 1.0)):
                self.site_forward(name, site, plates, result, *args, **kwargs)
        return result

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        self._event_dims = {}
        self.locs = PyroModule()
        self.scales = PyroModule()
        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = (
                    biject_to(site["fn"].support).inv(site["value"].detach()).detach()
                )
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim
            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)
            # like in https://github.com/TyXe-BDL/TyXe/blob/master/tyxe/guides.py
            if not self.freeze_locs or name.startswith("head_var"):
                init_loc = PyroParam(init_loc, constraints.real, event_dim)
            deep_setattr(self.locs, name, init_loc)
            deep_setattr(
                self.scales,
                name,
                PyroParam(init_scale, self.scale_constraint, event_dim),
            )
        if self.freeze_locs:
            self.freeze_loc_params()

    def freeze_loc_params(self):
        for n, p in self.locs.named_parameters():
            if not n.startswith("head_var"):
                p.requires_grad = False
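During training the anneal factor is passed as a keyword argument so it reaches both the model and the guide's forward(). A minimal sketch with a purely illustrative linear ramp (the actual schedule I use on GSC2 is not shown here):

# illustrative KL-annealing schedule: ramp the factor from ~0 to 1 over the first 500 steps
for step in range(1000):
    anneal = min(1.0, (step + 1) / 500)
    loss = svi.step(x, y, anneal_factor=anneal)  # SVI forwards args/kwargs to both model and guide
    if step % 100 == 0:
        print(step, anneal, loss)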
Thanks in advance!