# How to pretend a Normal is a multivariate Gaussian for TraceMeanField_ELBO?

Hello!

I was hoping to get some help understanding `to_event(1)`, particularly in the Gaussian case. Does calling `to_event(1)` on a Gaussian (`Normal`) distribution make it still act like a multivariate Gaussian when we use `TraceMeanField_ELBO`? That was my understanding from the Pyro tutorials on dimensions.

I'm asking because I saw some HUGE discrepancies between the KL term I extracted from my ELBO, following the suggestions in (Extract the KL divergence term from loss. How?), and the KL I computed explicitly from my parameters.

My understanding is that calling `.to_event(1)` on a Gaussian signals to Pyro, at a higher level, "hey, this is actually a multivariate Gaussian distribution." So my probabilistic functions look like this:

```python
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam

def model(data):
    # iterate over a minibatch of my data
    with pyro.plate("data", data.shape[0]):
        mus, sigmas = black_box_magic(data)  # hypothetical network returning (batch, latent_dim) tensors
        distr = dist.Normal(mus, sigmas).to_event(1)  # my pretend multivariate Gaussian
        z = pyro.sample("z", distr)
        # the rest of my fancy model

def guide(data):
    # iterate over a minibatch of my data
    with pyro.plate("data", data.shape[0]):
        mus, sigmas = other_black_box_magic(data)  # hypothetical network
        distr = dist.Normal(mus, sigmas).to_event(1)  # my pretend multivariate Gaussian
        pyro.sample("z", distr)

optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, TraceMeanField_ELBO())

svi.step(batch_o_data)  # my ELBO is presumably going to use the KL for 2 multivariates??
```

When I actually step through `TraceMeanField_ELBO` and look at the sites that use this trick, I notice that internally PyTorch seems to treat this as the regular KL between two Normals.

Any thoughts are appreciated!

As an alternative, how would I go about explicitly adding the analytical KL calculation for two diagonal Gaussian distributions? That is, where in the trace would I access the distribution parameters so that the gradients still propagate?
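For reference, with guide $q = \mathcal{N}(\mu_q, \operatorname{diag}(\sigma_q^2))$ and model $p = \mathcal{N}(\mu_p, \operatorname{diag}(\sigma_p^2))$ in $d$ dimensions, the analytic KL is the sum of $d$ univariate Normal KLs, which is why a diagonal MVN and `Normal(...).to_event(1)` should give the same value:

$$
\mathrm{KL}(q \,\|\, p) = \sum_{i=1}^{d} \left[ \log\frac{\sigma_{p,i}}{\sigma_{q,i}} + \frac{\sigma_{q,i}^2 + (\mu_{q,i} - \mu_{p,i})^2}{2\,\sigma_{p,i}^2} - \frac{1}{2} \right]
$$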

Hi @megaloman, PyTorch treats them as independent distributions and computes the KL based on this implementation. I think that is equivalent to the KL between two diagonal MVN distributions. You can test it by comparing

```python
kl_divergence(dist.Normal(loc1, scale1).to_event(1), dist.Normal(loc2, scale2).to_event(1))
```

and

```python
kl_divergence(dist.MultivariateNormal(loc1, scale1.pow(2).diag()), dist.MultivariateNormal(loc2, scale2.pow(2).diag()))
```

> How would I go about explicitly adding the analytical KL calculation for 2 diagonal Gaussian distributions

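Once you have a guide trace and a model trace, you can compute the KL similarly to how it is implemented in `TraceMeanField_ELBO`. `site_dist = trace.nodes[site_name]["fn"]` gives you the distribution at `site_name`, from which you can access the `loc` and `scale` parameters through `site_dist.loc` and `site_dist.scale`. For a site built with `.to_event(1)`, the `Normal` is wrapped in an `Independent`, so use `site_dist.base_dist.loc` and `site_dist.base_dist.scale`.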

Thanks for the tips!

I'll give it another look, but part of why I raised this issue is that when I tried something similar to what you suggested, the KL terms seemed to differ between the Normal(…).to_event(1) case and the MultivariateNormal(…) case.

If they do turn out to be different, is that a bug in Pyro/PyTorch? A bug on my end is obviously much more likely, but it would be good to clarify.

I think that if they are not the same, it is likely a bug in the PyTorch distributions module. Could you provide an example of an inconsistency in the `kl_divergence` computation?

Whew, so I played with your suggestion a bit, and now I see the mistake I was making. When constructing the MultivariateNormal, I wasn't squaring the diagonal (scale) terms as in your example. Without that `.pow(2)` on the scales, I got the same discrepancy I was seeing in my code. With it, like magic, the values matched whether I used Normal(…).to_event(1) or the MultivariateNormal.

Thank you so much for the help!

Glad it helped! I made the same mistake previously, but in the reverse direction: to get the `scale` parameter for a Normal distribution, I forgot to apply `sqrt` after taking the variance from the MVN.