Hi there!
Assume I have a model with a few latent variables (see below). So far I have been using an AutoNormal guide. However, I know that some of the latent variables are correlated, so I want to use an AutoMultivariateNormal guide instead.
Is there a way to specify which subset of the latent variables is correlated when defining the AutoMultivariateNormal guide?
import torch
import pyro
import pyro.distributions as dist

def model():
    a1 = pyro.sample("a1", dist.HalfCauchy(torch.ones(1)))
    a2 = pyro.sample("a2", dist.Normal(0, torch.ones(1)))
    a3 = pyro.sample("a3", dist.Normal(0, torch.ones(1)))
    a4 = pyro.sample("a4", dist.Gamma(torch.ones(1), torch.ones(1)))
    a5 = pyro.sample("a5", dist.Uniform(0, 10 * torch.ones(1)))
In this example, assume I know that the mean-field assumption (i.e. independence) holds for a1 and a2, but a3, a4 and a5 are correlated.