Pyro Parameters not updating in SVI

Hi,

I’m trying to implement Bayesian Negative Binomial Regression using Pyro’s SVI. Right now, it looks like the parameters related to `w` are not updating after the first iteration, and the loss stagnates. I’ve read previous posts, but I still cannot figure it out.

Am I setting up my model wrong, or is this a matter of tuning hyperparameters?

```
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

pyro.clear_param_store()
n_steps = 100000

def model(X, y):
    
    alpha = pyro.sample("alpha", dist.Gamma(1.0, 1.0))

    w = pyro.sample("w", dist.Normal(torch.zeros(4), torch.ones(4)))

    lambdas = torch.exp(torch.mv(X, w.clone().detach()))

    eps = 10e-5
    p = lambdas/(lambdas + 1)
    p = torch.clamp(p, eps, 1 - eps)
    return pyro.sample("y", 
                dist.NegativeBinomial(alpha, probs=p), 
                obs = y)

def guide(X, y):
    w_prior_mean_q = pyro.param("w_prior_mean_q", torch.zeros(4)) #, constraint=constraints.positive)
    w_prior_cov_q = pyro.param("w_prior_cov_q", torch.ones(4), constraint=constraints.positive)
    a_q = pyro.param("a_q", torch.tensor(3.0),
                         constraint=constraints.positive)
    b_q = pyro.param("b_q", torch.tensor(1.0),
                         constraint=constraints.positive)

    alpha = pyro.sample("alpha", dist.Normal(a_q, b_q)) 
    
    w = pyro.sample("w", dist.Normal(w_prior_mean_q, w_prior_cov_q))

# setup the optimizer
adam_params = {"lr": .01, "betas": (0.9, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

params = []
# do gradient steps
loss = 0
for step in range(n_steps):
    loss += svi.step(train_X.clone().detach(), train_y)
    if step % (n_steps // 100) == 0:
        params.append([pyro.param('a_q'),
                       pyro.param('b_q'),
                       pyro.param('w_prior_mean_q'), 
                       pyro.param('w_prior_cov_q')])
        print(100 * step/n_steps,
              loss / (n_steps/100)
             )
        loss = 0
```

Hey @abucquet,

I don’t know if this is the solution to your problem, but I think this

w = pyro.sample("w", dist.Normal(pyro.param("w_prior_mean_q"), w_prior_cov_q))

should probably be this

w = pyro.sample("w", dist.Normal(w_prior_mean_q, w_prior_cov_q))

Might be a start.

All the best,
Scipio

Hi @scipio,

Thanks for pointing this out! I just edited it, but I still have the same issue…

Why are you detaching gradients in `torch.mv(X, w.clone().detach())`? No information is being propagated back to your parameters.
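Concretely, something like this should let the gradient reach `w_prior_mean_q` and `w_prior_cov_q` again (just your model with the detach removed, everything else unchanged):

```
def model(X, y):
    alpha = pyro.sample("alpha", dist.Gamma(1.0, 1.0))
    w = pyro.sample("w", dist.Normal(torch.zeros(4), torch.ones(4)))

    # keep w attached to the autograd graph -- .clone().detach() cuts it off,
    # so the ELBO gradient never reaches the guide's parameters for w
    lambdas = torch.exp(torch.mv(X, w))

    eps = 10e-5
    p = torch.clamp(lambdas / (lambdas + 1), eps, 1 - eps)
    return pyro.sample("y", dist.NegativeBinomial(alpha, probs=p), obs=y)
```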


Oooops, you’re right, that completely eluded me. Thanks for your help!