I am learning Pyro and am bumping up against the same problem as SVI doesn't change parameters in that parameters don’t change and the ELBO fluctuates randomly. However, I have actually declared params in my guide. My use case is involved enough that I do not want to use an AutoGuide (sp?), which is why I am trying to write one myself. Here is a small reprex:
import os
import torch
from torch.distributions import constraints
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
def model(predictors, pheno):
n_predictors = predictors.shape[1]
gamma_0 = pyro.sample("g0", dist.Normal(0., 1.))
covs = pyro.sample("covs", dist.Normal(torch.zeros(n_predictors), torch.ones(n_predictors)).independent(1))
l = (gamma_0 + torch.matmul(predictors, covs.squeeze())).squeeze(-1)
with pyro.plate("pheno", len(pheno)):
pyro.sample("obs", dist.Bernoulli(logits=l), obs=pheno)
def guide(predictors, pheno):
n_predictors = predictors.shape[1]
g0_loc = pyro.param("g0_loc", torch.tensor(0.))
g0_scale = pyro.param("g0_scale", torch.tensor(1.), constraint=constraints.positive)
gamma_0 = pyro.sample("g0", dist.Normal(g0_loc, g0_scale))
covs_loc = pyro.param("covs_loc", torch.zeros(n_predictors))
covs = pyro.sample("covs", dist.Normal(covs_loc, 1.).independent(1))
torch.manual_seed(1)
predictors = torch.rand([100, 3])
coefs = torch.tensor([1., -5., 0.2])
pheno = torch.bernoulli(torch.sigmoid(0.2 + torch.matmul(predictors, coefs)))
pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 0.001})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
for i in range(1000):
loss = svi.step(predictors, pheno)
if i % 100 == 0:
print(loss)
predictive = Predictive(model, guide=guide, num_samples=100)
samples = predictive(predictors, pheno)
I feel that I am missing something important but simple - I should be able to use Pyro to fit this regression model, right?
Also, I parameterized covs
as a matrix of independent parameters because in my application, this would be a set of nuisance covariates one has to include, but not sure if this is the best way to do this in Pyro.