pyro.infer.trace_elbo

I’m explaining how I use Pyro’s SVI in a presentation to an audience more familiar with VAEs. To make things clearer, I’d like my slides to explicitly include the mathematical loss term Pyro uses under the hood and draw parallels with the “reconstruction term” and “regularization term” from popular treatments of typical VAEs.

I’m using the pyro.infer.Trace_ELBO loss. I’d like to connect the Pyro source (pyro.infer.trace_elbo — Pyro documentation) and how SVI works in Pyro to the equations in the paper referenced in the Pyro documentation, specifically equation 1 of the BBVI paper [1].
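For reference on the slides, equation 1 of [1] (matching the three bullets below, as I read it) is the ELBO:

```latex
% Equation 1 of [1]: the evidence lower bound (ELBO)
\mathcal{L}(\lambda) \triangleq
  \mathbb{E}_{q_\lambda(z)}\!\left[\log p(x, z) - \log q_\lambda(z)\right]
```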

In equation 1 we:

  • sample z ~ q_lam(z). This sampling happens in the guide, and the resulting sample is passed to (replayed in) the model
  • evaluate log p(x,z). Let’s decompose p(x,z) = p(x|z)p(z).
    • p(z) gives the log-prob of the prior over z: the sample comes from q in the guide, and the log-prob comes from the prior p specified in the model. When VAEs are discussed, this piece (together with the log q(z) term below) is what becomes the KL regularization term in this blog on VAEs (which assumes the sampling in the guide/encoder is Normal and that the prior in the model is N(0,1)).
    • p(x|z) corresponds to the observe statement in the model. My model has an observe statement representing noise on a clean_image compared to the noisy data: pyro.sample('label', distribution_noise(clean_image), obs=data). Therefore log p(x|z) would be the “reconstruction loss” under the pdf of distribution_noise. In the case of a Gaussian N(clean_image, sigma), this is (up to an additive constant) a scaled L2 loss: -|clean_image - data|^2 / (2*sigma^2)
  • evaluate log q(z). This is the likelihood of the latents under the distributions that govern them, which I specify in the guide. Each pyro.sample statement in the guide contributes to this term. (The equation and code sketch below put these three pieces together.)
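Substituting p(x,z) = p(x|z) p(z) into equation 1 and rearranging gives the familiar VAE form, with the reconstruction and KL regularization terms explicit:

```latex
\mathcal{L}(\lambda)
  = \mathbb{E}_{q_\lambda(z)}\!\left[\log p(x \mid z)\right]
    - \mathrm{KL}\!\left(q_\lambda(z)\,\middle\|\,p(z)\right)
```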
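And here is a minimal, runnable sketch of how I picture these pieces mapping onto Trace_ELBO. The 'label' site and the Gaussian distribution_noise follow my model above; the placeholder “decoder” z.sum(), the latent dimension, and sigma = 0.1 are made-up illustrations, not my real model:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints

sigma = 0.1  # assumed noise scale, for illustration only

def model(data):
    # p(z): the prior over the latent, specified in the model
    z = pyro.sample("z", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
    clean_image = z.sum()  # placeholder "decoder"; stands in for the real one
    # p(x|z): the observe statement -> the "reconstruction" term
    pyro.sample("label", dist.Normal(clean_image, sigma), obs=data)

def guide(data):
    # q_lam(z): the variational distribution with learnable parameters lam
    loc = pyro.param("loc", torch.zeros(2))
    scale = pyro.param("scale", torch.ones(2), constraint=constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale).to_event(1))

data = torch.tensor(0.5)

# A single-sample Monte Carlo estimate of equation 1, done by hand with traces:
guide_trace = poutine.trace(guide).get_trace(data)
replayed = poutine.replay(model, trace=guide_trace)  # reuse z ~ q_lam(z) in the model
model_trace = poutine.trace(replayed).get_trace(data)

log_p_xz = model_trace.log_prob_sum()  # log p(x|z) + log p(z)
log_q_z = guide_trace.log_prob_sum()   # log q_lam(z)
elbo_estimate = log_p_xz - log_q_z     # equation 1, with one sample z ~ q_lam

# Trace_ELBO builds the same estimate internally; .loss() returns -ELBO
# (a fresh z is drawn, so the two numbers agree only in expectation):
loss = pyro.infer.Trace_ELBO(num_particles=1).loss(model, guide, data)
```

As I understand it, for a fully reparameterized model like this, Trace_ELBO’s loss is exactly minus the model-trace log-prob sum plus the guide-trace log-prob sum, averaged over num_particles samples from the guide.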

[1] Rajesh Ranganath, Sean Gerrish, David M. Blei. “Black Box Variational Inference.” AISTATS 2014.

is there a question in there?

note that if you have reparameterizable latent variables (e.g. normal), pyro does not use “black box variational inference” (which relies on score function gradient estimators).
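for example, whether a sample site can use the pathwise (reparameterization) gradient shows up as the distribution’s has_rsample flag; Trace_ELBO falls back to score-function terms for non-reparameterizable sites. a quick sketch:

```python
import pyro.distributions as dist

# Normal supports rsample(), so Trace_ELBO can backprop through the sample
# (pathwise / reparameterization gradient):
print(dist.Normal(0., 1.).has_rsample)  # True

# Poisson does not, so a score-function (REINFORCE-style) term is used instead:
print(dist.Poisson(1.).has_rsample)     # False
```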

it’s quite a bit to read through, but I think all of this is explained reasonably well in SVI parts i–iii of the tutorials

My question is: “Am I understanding this correctly?”, per how I outlined my understanding above.

Thanks for the links, I’ll re-read the SVI parts :slight_smile: