SVI using MultivariateNormal

I tried making a minimal example of doing SVI with a MultivariateNormal distribution. I’m stuck.

First off, here is a minimal example with the univariate Normal. It’s working.

def model(x3_obs):
  """Generative chain x1 -> x2 -> x3 of unit-variance Normals.

  The site 'x3' is conditioned on the observations in ``x3_obs``;
  'x1' and 'x2' are latent.
  """
  prior_mean = tensor(1.)
  batch_size = len(x3_obs)
  with pyro.plate('mini_batch', batch_size):
    x1 = pyro.sample('x1', dist.Normal(prior_mean, 1))
    x2 = pyro.sample('x2', dist.Normal(x1, 1))
    x3 = pyro.sample('x3', dist.Normal(x2, 1), obs=x3_obs)
  return x1, x2, x3

x3_obs = tensor([1., 1.1, 1.2,])


def guide(x3_obs):
  """Mean-field guide with sample sites matching the model's latents 'x1' and 'x2'.

  All parameters start at a deliberately bad value (10.0) so that SVI's
  convergence toward the truth is visible.
  """
  init_value = 10.0
  x1_loc = pyro.param('x1_loc', tensor(init_value))
  x2_loc = pyro.param('x2_loc', tensor(init_value))
  x2_scale = pyro.param('x2_scale', tensor(init_value), constraint=constraints.positive)
  with pyro.plate('mini_batch', len(x3_obs)):
    pyro.sample('x1', dist.Normal(x1_loc, 1))
    pyro.sample('x2', dist.Normal(x2_loc, x2_scale))

# Reset any parameters left over from earlier runs.
pyro.clear_param_store()

svi = pyro.infer.SVI(model, guide, 
                     optim=pyro.optim.Adam({"lr": 1e-1}), 
                     loss=pyro.infer.Trace_ELBO(),
)

# Take 1000 gradient steps on the ELBO.
n_steps = 1000
for step in range(n_steps):
    svi.step(x3_obs)

# All three parameters move from their initial value of 10.0 to ~1, as expected.
print('{:1.3f}, {:1.3f}, {:1.3f}'.format(pyro.param('x1_loc').item(), pyro.param('x2_loc').round().item(),  pyro.param('x2_scale').round().item()))
  # 0.977, 1.000, 1.000

In the example below, I expected the pyro.param('loc') to change from (10,10) to something around (1,1). What is going on?

def model(data,n_batch):
  # NOTE(review): unlike the univariate example, this model doubles as the
  # data generator when called with data=None (see the call just below it).
  with pyro.plate('mini_batch',n_batch):
    loc_prior = tensor(1.)
    x1 = pyro.sample('x1',dist.Normal(loc_prior,1))
    x2 = pyro.sample('x2',dist.Normal(x1,1))
    # stack -> shape (2, n_batch); .T -> (n_batch, 2): one 2-d mean per datum.
    loc = torch.stack([x1,x2]).T
    x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)), obs=data)
  return x1, x2, x3

# Draw 100 ground-truth samples from the model itself and keep the observations.
x1_gt, x2_gt, x3_gt = model(None,n_batch=100)
data = x3_gt


# Start from a clean parameter store before training.
pyro.clear_param_store()

def guide(data, n_batch):
  # NOTE(review): this guide is the source of the problem. It contains a
  # sample statement for 'x3', the site the model *observes* (obs=data) --
  # an observed site must not appear in the guide at all. Meanwhile the
  # model's latent sites 'x1' and 'x2' have no matching sample statements
  # here, so the ELBO gives no gradient signal to 'loc' and it never moves
  # from its initial value of (10, 10).
  loc = pyro.param('loc',tensor([10,10.]))
  with pyro.plate('mini_batch',n_batch):
    x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)))

svi = pyro.infer.SVI(model, guide, 
                     optim=pyro.optim.Adam({"lr": 1e-1}), 
                     loss=pyro.infer.Trace_ELBO(),
)

# Same training loop as the working univariate example.
n_steps = 1000
for step in range(n_steps):
    svi.step(data,n_batch=100)

# 'loc' is unchanged after 1000 steps: the guide never ties this parameter
# to any latent site of the model, so it receives no gradient.
print(pyro.param('loc'))
  # tensor([10., 10.], requires_grad=True)

Hmmm, I understand why this is different from the univariate case: the guide has a sample statement for x3, which is observed in the model (an observed site should not appear in the guide at all), and it is missing sample statements for the latent variables x1 and x2 — so the parameter 'loc' never receives a gradient.

def model(data,n_batch):
  """Joint model: 2-d latent x12 ~ MVN(prior, I), observation x3 ~ MVN(x12, I).

  Called with data=None it acts as a sampler; with data it conditions 'x3'.
  """
  prior_mean = tensor([10.,20.])
  with pyro.plate('mini_batch',n_batch):
    x12 = pyro.sample('x12', dist.MultivariateNormal(prior_mean, torch.eye(2)))
    x3 = pyro.sample('x3', dist.MultivariateNormal(x12, torch.eye(2)), obs=data)
  return x12, x3

# Draw 100 ground-truth samples and keep the observations.
x12_gt, x3_gt = model(None,n_batch=100)
data = x3_gt

# Reset parameters before training.
pyro.clear_param_store()

def guide(data,n_batch):
  """Full-covariance Gaussian guide over the model's latent site 'x12'.

  Both parameters are initialised far from the truth on purpose, so that
  any learning is easy to spot.
  """
  assert n_batch == len(data)
  loc_init, cov_init = 5.0, 10.0
  x12_loc = pyro.param('x12_loc', tensor([loc_init, loc_init]))
  x12_cov = pyro.param('x12_cov', cov_init * torch.eye(2), constraint=constraints.positive_definite)
  with pyro.plate('mini_batch',n_batch):
    pyro.sample('x12', dist.MultivariateNormal(x12_loc, x12_cov))

svi = pyro.infer.SVI(model, guide, 
                     optim=pyro.optim.Adam({"lr": 1e-1}), 
                     loss=pyro.infer.Trace_ELBO(),
)

# FIX(review): the training loop was missing from the original snippet --
# without any svi.step calls the parameters printed below would still hold
# their initial values (5.0 and 10*I). Run the same loop as the examples above.
n_steps = 1000
for step in range(n_steps):
    svi.step(data, n_batch=100)

# After training both variational parameters have moved from their inits.
print(pyro.param('x12_loc'))
  # tensor([0.8735, 1.9312], requires_grad=True)
print(pyro.param('x12_cov'))
  # tensor([[0.4642, 0.0153],[0.0153, 0.4813]], grad_fn=<MmBackward0>)