Hi all,
I’m trying to replicate the Numpyro hierarchical regression tutorial by using VI instead of MCMC using Pyro version: 1.8. I’m having trouble getting a good convergence. I’ve tried Autoguides (AutoNormal and AutoLowRankMultivariateNormal) and a custom guide. The custom guide, which I wrote to be mean field, doesn’t seem to converge at all or has very high variance gradients. AutoLowRankMultivariateNormal seems to have the best convergence but the solution is still way off when looking at the predictive distribution.
pyro.clear_param_store()
def model(PatientID, Weeks, FVC_obs=None):
μ_α = pyro.sample("μ_α", dist.Normal(0.0, 100.0))
σ_α = pyro.sample("σ_α", dist.HalfNormal(100.0))
μ_β = pyro.sample("μ_β", dist.Normal(0.0, 100.0))
σ_β = pyro.sample("σ_β", dist.HalfNormal(100.0))
unique_patient_IDs = np.unique(PatientID)
n_patients = len(unique_patient_IDs)
with pyro.plate("plate_i", n_patients):
α = pyro.sample("α", dist.Normal(μ_α, σ_α))
β = pyro.sample("β", dist.Normal(μ_β, σ_β))
σ = pyro.sample("σ", dist.HalfNormal(100.0))
FVC_est = α[PatientID] + β[PatientID] * Weeks
with pyro.plate("data", len(PatientID)):
pyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)
def guide(PatientID, Weeks, FVC_obs=None):
global_μ_α = pyro.param("global_μ_α", torch.tensor(0.0))
global_μ_β = pyro.param("global_μ_β", torch.tensor(0.0))
global_σ_α = pyro.param("global_σ_α", torch.tensor(0.1), constraint=constraints.positive)
global_σ_β = pyro.param("global_σ_β", torch.tensor(0.1), constraint=constraints.positive)
μ_α = pyro.sample("μ_α", dist.Normal(global_μ_α, global_σ_α))
μ_β = pyro.sample("μ_β", dist.Normal(global_μ_β, global_σ_β))
σ_α = pyro.sample("σ_α", dist.HalfNormal(global_σ_α))
σ_β = pyro.sample("σ_β", dist.HalfNormal(global_σ_β))
global_σ = pyro.param("global_σ", torch.tensor(0.1), constraint=constraints.positive)
σ = pyro.sample("σ", dist.HalfNormal(global_σ))
n_patients = len(np.unique(PatientID))
patient_mus = pyro.param("local_mus", torch.tensor(0.0).expand(2, n_patients))
patient_scales = pyro.param("local_scales",
torch.tensor(0.1).expand(2, n_patients),
constraint=constraints.positive)
with pyro.plate("plate_i", n_patients):
pyro.sample("α", dist.Normal(patient_mus[0], patient_scales[0]))
pyro.sample("β", dist.Normal(patient_mus[1], patient_scales[1]))
# Uncomment for different guides
# guide = AutoNormal(model)
# guide = AutoLowRankMultivariateNormal(model)
wandb.init(project="pyro-hierarchical-regression", entity="kenleejr")
wandb.config.learning_rate = 0.003
wandb.config.epochs = 2000
svi = pyro.infer.SVI(model=model,
guide=guide,
optim=pyro.optim.Adam({"lr": wandb.config.learning_rate}),
loss=pyro.infer.Trace_ELBO())
v_names = list(pyro.get_param_store().keys())
for t in range(wandb.config.epochs):
elbo = svi.step(PatientID, Weeks, FVC_obs)
elbo_dict = {"ELBO": elbo}
wandb.log(elbo_dict)
if t % 100 == 0:
v_params = {v: pyro.param(v).data.numpy() for v in v_names}
wandb.log(v_params)
Here is a link to my notebook: variational_inference/Pyro_Hierarchical_Regression.ipynb at main · kenleejr/variational_inference · GitHub. Attached are images of the parameters during the optimization loop.