How to pretend a Normal is a multivariate Gaussian for TraceMeanField_ELBO?

Hello!

I was hoping to get some help understanding how to use to_event(1), particularly in the Gaussian case. Does calling to_event(1) on a Normal distribution still make it act like a MultivariateNormal when we use TraceMeanField_ELBO? This was my understanding based on the Pyro tutorials on dimensions.

I’m asking because I was seeing some HUGE discrepancies between the KL term I extracted from my ELBO, following the suggestions in (Extract the KL divergence term from loss. How? - #4 by martinjankowiak), and the KL I calculated explicitly from my parameters.

My understanding is that calling .to_event(1) on a Gaussian is, at a higher level, a signal to Pyro to pretend “hey, this is actually a multivariate Gaussian distribution”. So my probabilistic functions look like this:

# necessary imports
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam

def model(data):
    # doing this over a minibatch of my data
    with pyro.plate("data", data.shape[0]):
        mus, sigmas = BlackBoxMagic(data)  # placeholder for my prior network
        distr = dist.Normal(mus, sigmas).to_event(1)  # my pretend multivariate Gaussian
        z = pyro.sample("z", distr)
    # the rest of my fancy model

def guide(data):
    # doing this over a minibatch of my data
    with pyro.plate("data", data.shape[0]):
        mus, sigmas = OtherBlackBoxMagic(data)  # placeholder for my encoder network
        distr = dist.Normal(mus, sigmas).to_event(1)  # my pretend multivariate Gaussian
        pyro.sample("z", distr)

optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, TraceMeanField_ELBO())

svi.step(batch_o_data)  # is my ELBO going to use the KL between two multivariates??

When I actually step through TraceMeanField_ELBO and look at the sites that use this trick, it looks like internally PyTorch is treating this as the regular KL between two Normals.

Any thoughts are appreciated!

As an alternative, how would I go about explicitly adding the analytical KL calculation for two diagonal Gaussian distributions? For example, where in the trace would I access the distribution parameters so that the gradients still propagate?

Hi @megaloman, PyTorch treats them as independent distributions and computes the KL based on this implementation. I think that is equivalent to the KL between two diagonal MVN distributions. You can test it by comparing

kl_divergence(dist.Normal(loc1, scale1).to_event(1), dist.Normal(loc2, scale2).to_event(1))

and

kl_divergence(dist.MultivariateNormal(loc1, scale1.pow(2).diag()), dist.MultivariateNormal(loc2, scale2.pow(2).diag()))

how would I go about explicitly adding the analytical KL calculation for 2 diagonal gaussian distributions

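When you have a guide trace and a model trace, you can compute the KL similarly to what is implemented in TraceMeanField_ELBO. site_dist = trace.nodes[site_name]["fn"] gives you the distribution at site_name, and from there you can access its loc and scale parameters. Note that for a site built with .to_event(1), site_dist is an Independent wrapper, so loc and scale live on site_dist.base_dist (site_dist.base_dist.loc, site_dist.base_dist.scale); site_dist.mean and site_dist.stddev give the same values without unwrapping.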


Thanks for the tips!

I’ll give it another look, but part of why I raised this issue was that when I tried something similar to what you suggested, the KL terms were different for the Normal(…).to_event(1) case vs the MultivariateNormal(…) case.

If they end up actually not being the same, is that a bug in Pyro/PyTorch? A bug on my end is obviously much more likely, but it would be good to clarify.

I think that if they are not the same, then it is likely a bug in the PyTorch distributions module. Could you provide an example of an inconsistency in the kl_divergence computation?


Whew, so I played with your suggestion a bit, and now I see the mistake I was making. When I was constructing the MultivariateNormal, I wasn’t raising the diagonal terms to the power of 2 like in your example. Without that little call on the scale terms, I got the same discrepancy I was seeing in my code. With it, like magic, the values were the same for both Normal(…).to_event(1) and MultivariateNormal.

Thank you so much for the help!


Glad that it helps! I made the same mistake previously, but in reverse: to get the scale parameter for a Normal distribution, I forgot to apply sqrt after taking the variance from an MVN. :smiley:
