SVI not taking the data into account and focusing only on the prior when using a nonlinear function

  • What tutorial are you running?
    I am extending the tutorial “An Introduction to Inference in Pyro” by introducing a nonlinear function. Tutorial: (DEPRECATED) An Introduction to Inference in Pyro — Pyro Tutorials 1.8.4 documentation

  • Please link or paste relevant code, and steps to reproduce.
    After introducing a nonlinear function, the SVI results no longer seem to take the data into account and stay close to the prior. I set the data to a value significantly different from the prior, but the results do not change.

Does anyone have an idea where I went wrong?

Here is the code from the tutorial with some minor changes:

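import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import matplotlib.pyplot as plt

def some_nonlinear_function(weight):
    return 1/weight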

def intractable_scale(guess, data):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(some_nonlinear_function(weight), 0.75), obs=data)

def scale_parametrized_guide_constrained(guess, data):
    a = pyro.param("a", torch.tensor(guess))
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    return pyro.sample("weight", dist.Normal(a, b))  

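guess = 10.
# observed measurement as a tensor; 3.5 is really different from the prior mean of 10
data = torch.tensor(some_nonlinear_function(3.5))

# SGD is the optimizer actually used; Adam is an alternative
optim = pyro.optim.SGD({"lr": 0.001, "momentum": 0.1})
# optim = pyro.optim.Adam({"lr": 2e-2})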

pyro.clear_param_store()
svi = pyro.infer.SVI(model=intractable_scale,
                     guide=scale_parametrized_guide_constrained,
                     optim=optim,
                     loss=pyro.infer.Trace_ELBO(num_particles=1, retain_graph=True))


losses, a, b = [], [], []
num_steps = 2500
for t in range(num_steps):
    loss = svi.step(guess, data)
    losses.append(loss)
    if not t % 100:
        print('loss: ', t, ' ', loss)
    a.append(pyro.param("a").item())
    b.append(pyro.param("b").item())

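plt.figure()
plt.subplot(1,2,1)
# dotted line: analytic posterior mean (9.14) of the original linear tutorial, not of this model
plt.plot([0,num_steps],[9.14,9.14], 'k:')
plt.plot(a)
plt.ylabel('a')

plt.subplot(1,2,2)
plt.ylabel('b')
# dotted line: analytic posterior std (0.6) of the original linear tutorial, not of this model
plt.plot([0,num_steps],[0.6,0.6], 'k:')
plt.plot(b)
plt.tight_layout()
plt.show()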
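Because the model has only one latent variable, its exact posterior can be approximated on a grid and used as a reference instead of the dotted lines (9.14 and 0.6), which are the analytic values for the tutorial's original linear model. A minimal sketch using only the standard library, assuming the same prior Normal(guess, 1.0), likelihood Normal(1/weight, 0.75) and data = 1/3.5 as above; the helper names `normal_logpdf` and `grid_posterior` and the grid bounds are hypothetical choices for illustration:

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def grid_posterior(guess, data, noise_scale, lo=0.5, hi=20.0, n=20000):
    """Approximate the posterior mean and std of `weight` on a uniform grid.

    Model (mirrors intractable_scale): weight ~ Normal(guess, 1),
    measurement ~ Normal(1 / weight, noise_scale), observed value = data.
    The grid starts at lo > 0 to avoid the pole of 1/weight at 0;
    with guess = 10 the prior mass outside [lo, hi] is negligible.
    """
    step = (hi - lo) / (n - 1)
    ws = [lo + i * step for i in range(n)]
    log_post = [normal_logpdf(w, guess, 1.0) + normal_logpdf(data, 1.0 / w, noise_scale)
                for w in ws]
    m = max(log_post)
    probs = [math.exp(lp - m) for lp in log_post]  # unnormalised, shifted for stability
    z = sum(probs)
    mean = sum(w * p for w, p in zip(ws, probs)) / z
    var = sum((w - mean) ** 2 * p for w, p in zip(ws, probs)) / z
    return mean, math.sqrt(var)

mean, std = grid_posterior(guess=10.0, data=1 / 3.5, noise_scale=0.75)
print(f"posterior mean ~ {mean:.3f}, std ~ {std:.3f}")
```

If the grid mean stays close to 10 with a std close to 1, SVI converging to roughly the prior is the correct answer for this model rather than an optimization failure.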

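I just figured out that the problem was probably the scale of the measurement noise (the 0.75 standard deviation in intractable_scale), which I had set too big. Near the prior mean, 1/weight is about 0.1, and even at weight = 3.5 it is only about 0.29, so the difference is a small fraction of 0.75; the likelihood is then almost flat in weight and the prior dominates. I think this was the problem.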
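One way to check this explanation is to compare, in log-density units (nats), how strongly the prior favours weight = 10 over weight = 3.5 against how strongly the observation favours weight = 3.5 over weight = 10, for a few measurement scales. This is a sketch in plain Python that only compares two points rather than computing the full posterior; the function names and the scales tried are assumptions for illustration:

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

guess = 10.0            # prior mean of weight, as in the post
true_weight = 3.5       # value used to generate the data
data = 1 / true_weight  # observed measurement

def evidence_balance(noise_scale):
    """Return (prior_pull, likelihood_pull) in nats.

    prior_pull: how much more the prior Normal(guess, 1) favours weight = guess
    than weight = true_weight.
    likelihood_pull: how much more Normal(1/weight, noise_scale) favours
    weight = true_weight than weight = guess, given the observed data.
    """
    prior_pull = normal_logpdf(guess, guess, 1.0) - normal_logpdf(true_weight, guess, 1.0)
    likelihood_pull = (normal_logpdf(data, 1 / true_weight, noise_scale)
                       - normal_logpdf(data, 1 / guess, noise_scale))
    return prior_pull, likelihood_pull

for scale in (0.75, 0.1, 0.01):
    prior_pull, lik_pull = evidence_balance(scale)
    print(f"scale={scale}: prior favours {guess} by {prior_pull:.2f} nats, "
          f"data favour {true_weight} by {lik_pull:.3f} nats")
```

With a scale of 0.75 the data contribute only a few hundredths of a nat, far less than the prior's pull, so the posterior stays near 10; only a much smaller scale lets the observation outweigh the prior.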
