AutoMultivariateNormal Guide on subset of Latent Variables

Hi there!

Assume I have a model with several latent variables (see below). So far I have been using an AutoNormal guide, but I know that a few of the latent variables are correlated, so I want to use an AutoMultivariateNormal guide instead.
Is there a way to specify which subset of the latent variables is correlated when defining the AutoMultivariateNormal guide?

import torch
import pyro
import pyro.distributions as dist

def model():
    # Five scalar latent variables with different supports
    a1 = pyro.sample("a1", dist.HalfCauchy(torch.ones(1)))
    a2 = pyro.sample("a2", dist.Normal(0, torch.ones(1)))
    a3 = pyro.sample("a3", dist.Normal(0, torch.ones(1)))
    a4 = pyro.sample("a4", dist.Gamma(torch.ones(1), torch.ones(1)))
    a5 = pyro.sample("a5", dist.Uniform(0, 10*torch.ones(1)))

In this example, assume I know that the mean-field assumption (i.e., independence) holds for a1 and a2, but a3, a4, and a5 are correlated.

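Please refer to AutoGuideList, which lets you combine several autoguides, each covering a different subset of the latent variables.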

Thank you very much! Somehow I missed it.