What tutorial are you running?
I am extending the tutorial “An Introduction to Inference in Pyro” by introducing a non-linear function. Tutorial: (DEPRECATED) An Introduction to Inference in Pyro (Pyro Tutorials 1.8.4 documentation)
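For reference, the tutorial's original model is linear-Gaussian: weight ~ Normal(guess, 1) and measurement ~ Normal(weight, 0.75). Its posterior over the weight therefore has a closed form. The following minimal sketch assumes the tutorial's values guess = 8.5 and measurement = 9.5, and the helper name `conjugate_normal_posterior` is hypothetical. It reproduces the 9.14 and 0.6 used as dotted reference lines in the plotting code:

```python
import math


def conjugate_normal_posterior(prior_mean, prior_sd, obs, obs_sd):
    """Posterior of w for w ~ Normal(prior_mean, prior_sd), obs ~ Normal(w, obs_sd)."""
    prior_prec = 1.0 / prior_sd ** 2   # precision = 1 / variance
    obs_prec = 1.0 / obs_sd ** 2
    post_prec = prior_prec + obs_prec  # precisions add for a Normal-Normal model
    # Posterior mean is the precision-weighted average of prior mean and observation
    post_mean = (prior_mean * prior_prec + obs * obs_prec) / post_prec
    return post_mean, math.sqrt(1.0 / post_prec)


mean, sd = conjugate_normal_posterior(8.5, 1.0, 9.5, 0.75)
print(round(mean, 2), round(sd, 2))
```

This formula depends on the measurement being linear in the weight. With `1/weight` in the likelihood it no longer applies, so the 9.14 and 0.6 lines do not mark the expected answer for the modified model.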
Please link or paste relevant code, and steps to reproduce.
After introducing a non-linear function, the SVI results no longer seem to take the data into account; the fitted parameters stay close to the prior. I generated the observed data from a weight of 3.5, which is very different from the prior mean of 10, but the results do not change.
Does anyone have an idea where I went wrong?
Here is the code from the tutorial with some minor changes:
```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import matplotlib.pyplot as plt


def some_nonlinear_function(weight):
    return 1 / weight


def intractable_scale(guess, data):
    # Prior on the weight, centred on the guess
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    # Likelihood: the measurement is a noisy observation of 1/weight
    return pyro.sample("measurement", dist.Normal(some_nonlinear_function(weight), 0.75), obs=data)


def scale_parametrized_guide_constrained(guess, data):
    a = pyro.param("a", torch.tensor(guess))
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    return pyro.sample("weight", dist.Normal(a, b))


guess = 10.
# Observed value is a tensor; 3.5 is really different from the prior mean of 10
data = torch.tensor(some_nonlinear_function(3.5))

# optim = pyro.optim.Adam({"lr": 2e-2})  # unused: was overwritten by the SGD line below
optim = pyro.optim.SGD({"lr": 0.001, "momentum": 0.1})

pyro.clear_param_store()
svi = pyro.infer.SVI(model=intractable_scale,
                     guide=scale_parametrized_guide_constrained,
                     optim=optim,
                     loss=pyro.infer.Trace_ELBO(num_particles=1, retain_graph=True))

losses, a, b = [], [], []
num_steps = 2500
for t in range(num_steps):
    loss = svi.step(guess, data)
    losses.append(loss)
    if not t % 100:
        print('loss: ', t, ' ', loss)
    a.append(pyro.param("a").item())
    b.append(pyro.param("b").item())

# Both panels go in one figure, so plt.show() is called only once at the end
plt.figure()
plt.subplot(1, 2, 1)
plt.plot([0, num_steps], [9.14, 9.14], 'k:')  # reference value from the tutorial's linear model
plt.plot(a)
plt.ylabel('a')
plt.subplot(1, 2, 2)
plt.ylabel('b')
plt.plot([0, num_steps], [0.6, 0.6], 'k:')  # reference value from the tutorial's linear model
plt.plot(b)
plt.tight_layout()
plt.show()
```
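One way to check whether SVI is at fault is to compare it with the exact posterior. Because this model has a single latent variable, that posterior can be approximated numerically on a grid. The sketch below uses only the standard library and assumes the same setup as the code above: prior Normal(10, 1) on the weight, observation noise 0.75, and one observation equal to 1/3.5. The helper names `log_normal_pdf` and `grid_posterior` and the grid bounds and resolution (`lo`, `hi`, `n`) are hypothetical choices for illustration.

```python
import math


def log_normal_pdf(x, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at x
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def grid_posterior(guess=10.0, true_weight=3.5, obs_sd=0.75, lo=4.0, hi=16.0, n=20001):
    """Approximate mean and sd of p(weight | measurement) on a uniform grid.

    Model (same as intractable_scale): weight ~ Normal(guess, 1),
    measurement ~ Normal(1 / weight, obs_sd), observed value 1 / true_weight.
    The default grid [4, 16] covers +/- 6 prior standard deviations around 10.
    """
    y = 1.0 / true_weight
    ws = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    # Unnormalised log posterior = log prior + log likelihood at each grid point
    log_post = [log_normal_pdf(w, guess, 1.0) + log_normal_pdf(y, 1.0 / w, obs_sd) for w in ws]
    # Subtract the max before exponentiating for numerical stability
    m = max(log_post)
    probs = [math.exp(v - m) for v in log_post]
    z = sum(probs)
    mean = sum(w * p for w, p in zip(ws, probs)) / z
    var = sum((w - mean) ** 2 * p for w, p in zip(ws, probs)) / z
    return mean, math.sqrt(var)


mean, sd = grid_posterior()
print(f"grid posterior: mean = {mean:.3f}, sd = {sd:.3f}")
```

Suppose the grid mean and sd also stay close to 10 and 1. Then the exact posterior is itself close to the prior. For weights within three prior standard deviations of 10, 1/weight only ranges from about 0.08 to 0.14. The gap to the observed 1/3.5 ≈ 0.29 is small compared with the noise scale of 0.75. In that case the SVI output would be consistent with the model rather than pointing to a bug in the inference code.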