Modeling a 2D Multivariate Gaussian Distribution

Hey all,

I am new to Pyro, and for learning purposes I want to fit a 2D multivariate Gaussian distribution, but I get the following error:

ValueError: at site "mu", invalid log_prob shape
  Expected [], actual [2]

What do I have to fix in the model function?

Here is the code:

import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO

torch.manual_seed(1)
true_mu = torch.tensor([0.0, 0.0])
true_sigma = torch.tensor([1.0, 2.0])  # diagonal of the covariance matrix (variances)
true_cov = torch.tensor([0.75])        # off-diagonal covariance

true_cov_matrix = torch.tensor([[true_sigma[0], true_cov[0]],
                                [true_cov[0], true_sigma[1]]])
actual_data = dist.MultivariateNormal(true_mu, true_cov_matrix)
samples = actual_data.sample(torch.Size([20]))

# create a pyro model that samples from a 2D gaussian
def model(data):
    mu = pyro.sample("mu", dist.Normal(torch.zeros(2), torch.ones(2)))
    sigma = pyro.sample("sigma", dist.LogNormal(torch.zeros(2), torch.ones(2)))
    rho = pyro.sample("rho", dist.LKJ(2, 2)) 
    cov = torch.outer(sigma, sigma) * rho
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.MultivariateNormal(mu, cov), obs=data)

auto_guide = pyro.infer.autoguide.AutoMultivariateNormal(model)

svi = SVI(model, 
          auto_guide, 
          optim.Adam({"lr": .05}), 
          loss=Trace_ELBO())

pyro.clear_param_store()
num_iters = 2000
for i in range(num_iters):
    elbo = svi.step(samples)

# raises: ValueError: at site "mu", invalid log_prob shape. Expected [], actual [2]

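Check the docs or this link on distribution shapes. When you sample mu from dist.Normal(torch.zeros(2), torch.ones(2)) you get a two-dimensional vector, but you have to tell Pyro how to treat its two values. As written, that distribution has batch_shape [2] and event_shape [], so its log_prob has shape [2]: one log-density per value. Because the "mu" site is not inside any plate, Pyro expects a single scalar log_prob (shape []), hence the error. The same applies to "sigma".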

The simplest option is to assume that the two dimensions are conditionally independent and declare that with a plate:

with pyro.plate('dims', 2):
    mu = pyro.sample("mu", dist.Normal(torch.zeros(2), torch.ones(2)))
    sigma = pyro.sample("sigma", dist.LogNormal(torch.zeros(2), torch.ones(2)))

Alternatively, use to_event(1), which declares the two values a single vector-valued event so that log_prob is summed to a scalar:

mu = pyro.sample("mu", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
sigma = pyro.sample("sigma", dist.LogNormal(torch.zeros(2), torch.ones(2)).to_event(1))

Check the links above to get an idea of how distribution shapes work, then decide which option fits your model.

Hope this helps.