I was explaining how I’m using Pyro’s SVI in a presentation to an audience more familiar with VAEs. To make things clearer, I’d like to explicitly include a mathy loss term used by Pyro under the hood in my slides and make parallels with the “reconstruction term” and “regularization term” in popular treatments of typical VAEs.
I’m using the `pyro.infer.Trace_ELBO` loss. I’d like to connect the Pyro source (pyro.infer.trace_elbo in the Pyro documentation) and how SVI works in Pyro to the equations in the paper referenced in the Pyro documentation, specifically equation 1 in the BBVI paper [1].
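For reference, this is how I would write equation 1 of [1] on a slide, followed by the usual VAE-style regrouping obtained by substituting p(x,z) = p(x|z)p(z). The first line is the paper's ELBO as I read it. The second line (reconstruction term minus KL regularization term) is the standard rearrangement, not a formula taken from the Pyro source. As far as I understand, `Trace_ELBO` returns a Monte Carlo estimate of the negative of this quantity as the loss.

```latex
\begin{align}
\mathcal{L}(\lambda)
  &= \mathbb{E}_{q_\lambda(z)}\big[\log p(x, z) - \log q_\lambda(z)\big] \\
  &= \mathbb{E}_{q_\lambda(z)}\big[\log p(x \mid z)\big]
     - \mathrm{KL}\big(q_\lambda(z) \,\|\, p(z)\big)
\end{align}
```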
In equation 1 we
- sample z ~ q_lam(z). This happens in the guide, and the sampled value is passed to the model.
- evaluate log p(x,z). Let’s decompose p(x,z) = p(x|z)p(z).
- log p(z) would give the logprob of z under the prior: the sample is given by q in the guide and the logprob comes from the prior p, specified in the model. When VAEs are discussed, this term combined with -log q(z) is what is referred to as the KL regularization term, since E_q[log p(z) - log q(z)] = -KL(q(z) || p(z)). This is how it appears in this blog on VAEs (which assumes the sampling in the guide/encoder is Normal and that the prior in the model is N(0,1)).
- log p(x|z) would come from the observe statement in the model. My model has an observe statement representing the noise of a `clean_image` compared to the noisy data: `pyro.sample('label', distribution_noise(clean_image), obs=data)`. Therefore log p(x|z) would be the “reconstruction loss” under the pdf of the `distribution_noise`. In the case of a Gaussian distribution N(clean_image, sigma), this would be the L2 loss, -|clean_image - data|^2 / (2*sigma^2), up to an additive constant that depends only on sigma.
- evaluate log q(z). This is the log-density of the sampled latents under the distributions I specify in the guide. Each `pyro.sample` statement in the guide contributes to this term.
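To check this reading numerically, here is a minimal plain-Python sketch of the three terms for a toy 1-D problem. It does not call Pyro at all. The names `normal_logpdf`, `elbo_terms` and `kl_normal_vs_standard`, the "decoder" `clean_image = 2.0 * z`, and all numeric values are hypothetical stand-ins chosen for illustration. It assumes a Normal guide, an N(0,1) prior and Gaussian observation noise.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * scale**2) - (x - loc) ** 2 / (2 * scale**2)


def elbo_terms(data, q_loc, q_scale, sigma, rng):
    """One-sample estimates of the three log-density terms in equation 1."""
    # Guide: z ~ q_lam(z), here Normal(q_loc, q_scale).
    z = rng.gauss(q_loc, q_scale)
    # Model prior: log p(z) with p(z) = N(0, 1), scored at the guide's sample.
    log_p_z = normal_logpdf(z, 0.0, 1.0)
    # Model observe statement: log p(x|z) with x ~ N(clean_image, sigma).
    clean_image = 2.0 * z  # hypothetical stand-in for a decoder
    log_p_x_given_z = normal_logpdf(data, clean_image, sigma)
    # Guide density: log q(z) at the same sample.
    log_q_z = normal_logpdf(z, q_loc, q_scale)
    return log_p_x_given_z, log_p_z, log_q_z, clean_image


def kl_normal_vs_standard(q_loc, q_scale):
    """Closed-form KL(N(q_loc, q_scale) || N(0, 1))."""
    return 0.5 * (q_scale**2 + q_loc**2 - 1.0) - math.log(q_scale)


rng = random.Random(0)
data, sigma, q_loc, q_scale = 1.2, 0.5, 0.3, 0.8

# Single-sample ELBO estimate: reconstruction + prior - guide.
recon, log_prior, log_q, clean_image = elbo_terms(data, q_loc, q_scale, sigma, rng)
elbo_estimate = recon + log_prior - log_q

# For Gaussian noise, the reconstruction term is the scaled negative
# squared error plus a constant that depends only on sigma.
l2_part = -(clean_image - data) ** 2 / (2 * sigma**2)
gauss_const = -0.5 * math.log(2 * math.pi * sigma**2)

# Averaging log p(z) - log q(z) over many guide samples approaches -KL(q || p).
n = 20000
mc_neg_kl = 0.0
for _ in range(n):
    _, lp, lq, _ = elbo_terms(data, q_loc, q_scale, sigma, rng)
    mc_neg_kl += (lp - lq) / n
exact_kl = kl_normal_vs_standard(q_loc, q_scale)

print("single-sample ELBO estimate:", elbo_estimate)
print("recon minus (L2 part + const):", recon - (l2_part + gauss_const))
print("Monte Carlo -KL:", mc_neg_kl, " closed-form -KL:", -exact_kl)
```

The point of the last comparison is that the KL regularizer is recovered only when log p(z) and -log q(z) are taken together and averaged over guide samples. A single sample of log p(z) on its own is not the KL term.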
[1] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei