# SVI using MultivariateNormal

I tried to make a minimal example of SVI with a `MultivariateNormal` distribution, but I’m stuck.

First, here is a minimal example with the univariate `Normal` distribution. It works:

``````
def model(x3_obs):
    """Three-level Gaussian chain x1 -> x2 -> x3 with x3 observed.

    For each element of ``x3_obs`` (one plate index per element):
    x1 ~ Normal(1, 1), x2 ~ Normal(x1, 1), and x3 ~ Normal(x2, 1) observed
    as ``x3_obs``. Returns the values of the three sites.
    """
    loc_prior = tensor(1.)
    # len(x3_obs) sets the plate size; the plate declares the draws
    # conditionally independent across data points.
    with pyro.plate('mini_batch',len(x3_obs)):
        x1 = pyro.sample('x1',dist.Normal(loc_prior,1))
        x2 = pyro.sample('x2',dist.Normal(x1,1))
        x3 = pyro.sample('x3',dist.Normal(x2,1), obs=x3_obs)
    return x1, x2, x3

# Three observed values of x3; model and guide size their plates as len(x3_obs).
x3_obs = tensor([1., 1.1, 1.2,])

def guide(x3_obs):
    """Mean-field guide for the latent sites 'x1' and 'x2' of ``model``.

    q(x1) = Normal(x1_loc, 1) with a fixed unit scale, and
    q(x2) = Normal(x2_loc, x2_scale). The same learned parameters are shared
    by every plate index. ``x3_obs`` is only used for the plate size; SVI
    calls the guide with the same arguments as the model.
    """
    # A fixed (not random) starting value. It is far from the trained values
    # reported below, presumably so that convergence is visible.
    random_large_value = 10.0
    x1_loc = pyro.param('x1_loc',tensor(random_large_value))
    x2_loc = pyro.param('x2_loc',tensor(random_large_value))
    # constraints.positive keeps the learned scale > 0 during optimization.
    x2_scale = pyro.param('x2_scale',tensor(random_large_value), constraint=constraints.positive)
    with pyro.plate('mini_batch',len(x3_obs)):
        x1 = pyro.sample('x1',dist.Normal(x1_loc,1))
        x2 = pyro.sample('x2',dist.Normal(x2_loc,x2_scale))

# Start from an empty parameter store so no values from a previous run are reused.
pyro.clear_param_store()

# NOTE(review): pyro.infer.SVI has a required ``optim`` argument, which is
# missing from the snippet as shown. Without it, the constructor raises a
# TypeError. It was presumably lost when the post was copied. The Adam
# learning rate below is a placeholder -- TODO confirm against the original.
svi = pyro.infer.SVI(model, guide,
                     optim=pyro.optim.Adam({'lr': 0.1}),
                     loss=pyro.infer.Trace_ELBO(),
                     )

n_steps = 1000
# Each call runs model and guide on x3_obs and takes one optimizer step on the ELBO loss.
for _ in range(n_steps):
    svi.step(x3_obs)

# x2_loc and x2_scale are rounded before formatting, so they always print as
# whole numbers (e.g. 1.000). Only x1_loc shows its fractional part.
print('{:1.3f}, {:1.3f}, {:1.3f}'.format(
    pyro.param('x1_loc').item(),
    pyro.param('x2_loc').round().item(),
    pyro.param('x2_scale').round().item()))
# Output reported in the post (SVI is stochastic, so values vary between runs):
# 0.977, 1.000, 1.000
``````

In the example below, I expected `pyro.param('loc')` to change from `(10, 10)` to something around `(1, 1)`. What is going on?

``````def model(data,n_batch):
with pyro.plate('mini_batch',n_batch):
loc_prior = tensor(1.)
x1 = pyro.sample('x1',dist.Normal(loc_prior,1))
x2 = pyro.sample('x2',dist.Normal(x1,1))
loc = torch.stack([x1,x2]).T
x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)), obs=data)
return x1, x2, x3

# Simulate a synthetic data set. With data=None, the 'x3' site is not observed,
# so model() samples it: one 2-D row per plate index, 100 rows here.
x1_gt, x2_gt, x3_gt = model(None,n_batch=100)
# The sampled x3 values become the observations for inference below.
data = x3_gt

# Start from an empty parameter store before fitting.
pyro.clear_param_store()

def guide(data, n_batch):
    """Guide for the model above, as posted (see the review note below).

    Learns one 2-D location ``loc`` (initialized to (10, 10)) that is shared
    by all plate indices. For each of the ``n_batch`` indices, it samples the
    site 'x3' from MultivariateNormal(loc, I).
    """
    # NOTE(review): 'x3' is observed in the model (obs=data). The model's
    # latent sites 'x1' and 'x2' have no sample statement in this guide, so
    # Pyro draws them from the model's prior.
    #
    # For an observed site, Pyro uses the observed value in the model, so
    # ``loc`` never enters the model's log-density. The guide's log-density of
    # its own reparameterized sample (loc + eps) does not depend on ``loc``
    # either. The ELBO gradient w.r.t. ``loc`` is therefore presumably zero,
    # which would explain why it stays at (10, 10).
    #
    # A working guide would sample 'x1' and 'x2' instead, as in the
    # univariate example.
    loc = pyro.param('loc',tensor([10,10.]))
    with pyro.plate('mini_batch',n_batch):
        x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)))

# NOTE(review): pyro.infer.SVI has a required ``optim`` argument, which is
# missing from the snippet as shown. Without it, the constructor raises a
# TypeError. It was presumably lost when the post was copied. The Adam
# learning rate below is a placeholder -- TODO confirm against the original.
svi = pyro.infer.SVI(model, guide,
                     optim=pyro.optim.Adam({'lr': 0.1}),
                     loss=pyro.infer.Trace_ELBO(),
                     )

n_steps = 1000
# Each call runs model and guide on the simulated data and takes one optimizer step.
for _ in range(n_steps):
    svi.step(data, n_batch=100)

# The post expected this to move from (10, 10) towards roughly (1, 1).
print(pyro.param('loc'))

``````

Hmm, I understand why this is different from the univariate case: I am observing `x3` in the model, but I don’t have sample statements for `x1` and `x2` in the guide.

``````
def model(data,n_batch):
    """Two-level Gaussian model with a 2-D latent x12 and a 2-D observation x3.

    For each of the ``n_batch`` plate indices:
    x12 ~ MultivariateNormal((10, 20), I) and x3 ~ MultivariateNormal(x12, I).
    ``data`` is the observed value of x3; pass None to sample it instead.
    Returns (x12, x3).
    """
    loc_prior = tensor([10.,20.])
    with pyro.plate('mini_batch',n_batch):
        # The second positional argument of MultivariateNormal is the covariance matrix.
        x12 = pyro.sample('x12',dist.MultivariateNormal(loc_prior,torch.eye(2)))
        x3 = pyro.sample('x3',dist.MultivariateNormal(x12,torch.eye(2)), obs=data)
    return x12, x3

# Simulate 100 data points from the model. With data=None, the 'x3' site is not
# observed, so it is sampled.
x12_gt, x3_gt = model(None,n_batch=100)
# The sampled x3 values become the observations for inference.
data = x3_gt

# Start from an empty parameter store before fitting.
pyro.clear_param_store()

def guide(data,n_batch):
    """Guide for the latent site 'x12': MultivariateNormal(x12_loc, x12_cov).

    The same learned mean and covariance are shared by all ``n_batch`` plate
    indices. ``data`` is only used to check that ``n_batch`` matches it.
    """
    # Sanity check that the plate size agrees with the data
    # (assert statements are skipped under ``python -O``).
    assert n_batch == len(data)
    # Fixed (not random) initial values: mean (5, 5) and covariance 10 * I.
    random_large_value_1, random_large_value_2 = 5.0, 10.0
    x12_loc = pyro.param('x12_loc',tensor([random_large_value_1,random_large_value_1]))
    # Constrained so the learned matrix stays positive definite, i.e. a valid covariance.
    x12_cov = pyro.param('x12_cov',random_large_value_2*torch.eye(2), constraint=constraints.positive_definite)
    with pyro.plate('mini_batch',n_batch):
        # Second positional argument = covariance matrix.
        x12 = pyro.sample('x12',dist.MultivariateNormal(x12_loc,x12_cov))

svi = pyro.infer.SVI(model, guide,