Hi, I am in the process of learning Pyro. First of all, it is a great extension of PyTorch, and I love the API already.
I am trying to make a comparison using an easy example that I can also compute with numpy. Suppose I have observed two data points: lengths = [195, 185].
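As a numpy reference for this example, the posterior can be approximated on a grid. This is only a sketch: it assumes a Normal(200, 15) prior on mu and a HalfCauchy(10) prior on sigma, matching the models below, and the grid bounds and resolution are arbitrary choices that truncate the heavy tail of the HalfCauchy.

```python
import numpy as np

lengths = np.array([195.0, 185.0])

# Grid over both parameters (bounds and resolution are assumptions).
mu_grid = np.linspace(140.0, 260.0, 601)
sigma_grid = np.linspace(0.1, 100.0, 1000)
MU, SIGMA = np.meshgrid(mu_grid, sigma_grid, indexing="ij")

# Log priors, up to additive constants.
log_prior_mu = -0.5 * ((MU - 200.0) / 15.0) ** 2      # Normal(200, 15)
log_prior_sigma = -np.log1p((SIGMA / 10.0) ** 2)      # HalfCauchy(10)

# Gaussian log-likelihood of the observations, up to an additive constant.
log_lik = sum(-np.log(SIGMA) - 0.5 * ((x - MU) / SIGMA) ** 2 for x in lengths)

# Normalize the unnormalized posterior over the grid.
log_post = log_prior_mu + log_prior_sigma + log_lik
post = np.exp(log_post - log_post.max())
post /= post.sum()

# Posterior means as simple grid summaries.
post_mean_mu = (post * MU).sum()
post_mean_sigma = (post * SIGMA).sum()
print(post_mean_mu, post_mean_sigma)
```

The grid means give a rough baseline against which both the PyMC3 trace and the variational fit can be checked.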
In PyMC3 my model looks as follows:
    import numpy as np
    import pymc3 as pm

    lengths = np.array([195., 185.])

    with pm.Model():
        # priors
        mu = pm.Normal('mu', mu=200, sd=15)
        sigma = pm.HalfCauchy('sigma', 10)
        # likelihood
        observed = pm.Normal('observed', mu=mu, sd=sigma, observed=lengths)
        # sample
        trace = pm.sample(draws=1000, chains=1)
This model samples well in PyMC3. Now I want to compare it with variational inference.
My Pyro code:
    import torch
    import pyro
    import pyro.distributions as dist
    from torch.distributions import constraints

    def model(lengths):
        # priors
        mu = pyro.sample('mu', dist.Normal(loc=torch.tensor(200.), scale=torch.tensor(15.)))
        sigma = pyro.sample('sigma', dist.HalfCauchy(scale=torch.tensor(10.)))
        # likelihood: one observed site per data point
        for i in pyro.plate('plate', size=2):
            pyro.sample(f'obs_{i}', dist.Normal(loc=mu, scale=sigma), obs=lengths[i])

    def guide(lengths):
        # posterior mu
        mu_normal_mu = pyro.param('mu_normal_mu', torch.tensor(200.))
        mu_normal_sigma = pyro.param('mu_normal_sigma', torch.tensor(5.),
                                     constraint=constraints.positive)
        mu = pyro.sample('mu', dist.Normal(loc=mu_normal_mu, scale=mu_normal_sigma))
        # posterior sigma
        sigma_halfnormal_sigma = pyro.param('sigma_halfnormal_sigma', torch.tensor(5.),
                                            constraint=constraints.positive)
        sigma = pyro.sample('sigma', dist.HalfNormal(sigma_halfnormal_sigma))
However, if I run inference on the model and the guide, I see that the learned sigma parameters keep increasing until numeric overflow occurs. What am I doing wrong here?
For completeness, here is the inference snippet:
    from pyro.infer import SVI, Trace_ELBO

    pyro.clear_param_store()
    pyro.enable_validation(True)
    # set up the inference algorithm
    svi = SVI(model, guide,
              optim=pyro.optim.SGD({"lr": 0.001}),
              loss=Trace_ELBO())
    # min-max scale the observations to [0, 1] before passing them to SVI
    data = torch.tensor((lengths - lengths.min()) / (lengths.max() - lengths.min()),
                        dtype=torch.float32)
    # do gradient steps
    for step in range(5000):
        loss = svi.step(data)
        if step % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (step + 1, loss))
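To make explicit what the model actually receives in the loop above, here is a small sketch of the min-max scaling applied to the two observations. It assumes `lengths` is the numpy array [195, 185] from the start of the post.

```python
import numpy as np

lengths = np.array([195.0, 185.0])

# Same min-max scaling as in the call to svi.step.
scaled = (lengths - lengths.min()) / (lengths.max() - lengths.min())
print(scaled)  # [1. 0.]
```

So the observations passed to `svi.step` are 1 and 0, while the prior on mu in the model is still Normal(200, 15), which is defined on the original, unscaled length scale.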