Diagonal normal identification with guide

When running SVI with the guide below, it is described as a diagonal normal (DiagNormal) guide.

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def guide(is_cont_africa, ruggedness, log_gdp):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.randn(3))
    weights_scale = pyro.param('weights_scale', torch.ones(3),
                               constraint=constraints.positive)
    # Each site gets its own univariate Normal; no site depends on another sampled value.
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    b_ar = pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

Can someone please explain how we reach that conclusion? I am new to this domain.

We say this is a “diagonal normal” guide because the learned posterior is independent across the variables a, bA, bR, bAR and sigma: each one is drawn from its own univariate Normal whose parameters do not depend on any other sampled variable. Equivalently, their joint posterior is a multivariate normal with a diagonal covariance matrix. Alternatively, we could have used a multivariate normal guide, either automatically with AutoMultivariateNormal or by manually encoding a covariance structure. For example, here’s a simple non-diagonal guide over two variables a and b:

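import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
def guide():
    loc = pyro.param("loc", torch.zeros(2))
    scale_tril = pyro.param(
        "scale_tril", torch.eye(2), constraint=constraints.lower_cholesky
    )
    a = pyro.sample("a", dist.Normal(loc[0], scale_tril[0, 0]))
    # b's mean depends on the sampled value of a, coupling the two variables
    b = pyro.sample(
        "b", dist.Normal(loc[1] + a * scale_tril[1, 0], scale_tril[1, 1])
    )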

Whereas the diagonal guide has no posterior dependencies between random variables, this guide makes b depend on a through the term a * scale_tril[1,0].

More generally, guides can have non-Gaussian posteriors with arbitrary dependency structure.

This is a great explanation, and thanks for sharing the non-diagonal code as well. Just to clarify something: the guide would have been non-diagonal even if b had been sampled as follows:

b = pyro.sample(
    "b", dist.Normal(loc[1] + a, scale_tril[1, 1])
)

Is this understanding correct?

That’s correct: that would also have been non-diagonal, just not learnably multivariate, because the coefficient linking b to a is fixed at 1 instead of being a learned parameter such as scale_tril[1,0]. It’s actually pretty common to use such tricks in guides, or even when reparametrizing models, e.g. in non-centering transforms where a local variable is known to have a prior mean equal to some global variable.