I tried making a minimal example of doing SVI with a MultivariateNormal
distribution. I’m stuck.
First off, here is a minimal example with the univariate Normal
. It’s working.
def model(x3_obs):
loc_prior = tensor(1.)
with pyro.plate('mini_batch',len(x3_obs)):
x1 = pyro.sample('x1',dist.Normal(loc_prior,1))
x2 = pyro.sample('x2',dist.Normal(x1,1))
x3 = pyro.sample('x3',dist.Normal(x2,1), obs=x3_obs)
return x1, x2, x3
x3_obs = tensor([1., 1.1, 1.2,])
def guide(x3_obs):
random_large_value = 10.0
x1_loc = pyro.param('x1_loc',tensor(random_large_value))
x2_loc = pyro.param('x2_loc',tensor(random_large_value))
x2_scale = pyro.param('x2_scale',tensor(random_large_value), constraint=constraints.positive)
with pyro.plate('mini_batch',len(x3_obs)):
x1 = pyro.sample('x1',dist.Normal(x1_loc,1))
x2 = pyro.sample('x2',dist.Normal(x2_loc,x2_scale))
pyro.clear_param_store()
svi = pyro.infer.SVI(model, guide,
optim=pyro.optim.Adam({"lr": 1e-1}),
loss=pyro.infer.Trace_ELBO(),
)
n_steps = 1000
for step in range(n_steps):
svi.step(x3_obs)
print('{:1.3f}, {:1.3f}, {:1.3f}'.format(pyro.param('x1_loc').item(), pyro.param('x2_loc').round().item(), pyro.param('x2_scale').round().item()))
# 0.977, 1.000, 1.000
In the example below, I expected the pyro.param('loc')
to change from (10,10)
to something around (1,1)
. What is going on?
def model(data,n_batch):
with pyro.plate('mini_batch',n_batch):
loc_prior = tensor(1.)
x1 = pyro.sample('x1',dist.Normal(loc_prior,1))
x2 = pyro.sample('x2',dist.Normal(x1,1))
loc = torch.stack([x1,x2]).T
x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)), obs=data)
return x1, x2, x3
x1_gt, x2_gt, x3_gt = model(None,n_batch=100)
data = x3_gt
pyro.clear_param_store()
def guide(data, n_batch):
loc = pyro.param('loc',tensor([10,10.]))
with pyro.plate('mini_batch',n_batch):
x3 = pyro.sample('x3',dist.MultivariateNormal(loc,scale_tril=torch.eye(2)))
svi = pyro.infer.SVI(model, guide,
optim=pyro.optim.Adam({"lr": 1e-1}),
loss=pyro.infer.Trace_ELBO(),
)
n_steps = 1000
for step in range(n_steps):
svi.step(data,n_batch=100)
print(pyro.param('loc'))
# tensor([10., 10.], requires_grad=True)