Hierarchical Regression with VI

Hi all,

I’m trying to replicate the NumPyro hierarchical regression tutorial with VI instead of MCMC, using Pyro 1.8. I’m having trouble getting good convergence. I’ve tried autoguides (AutoNormal and AutoLowRankMultivariateNormal) and a custom guide. The custom guide, which I wrote to be mean-field, either doesn’t converge at all or has very high-variance gradients. AutoLowRankMultivariateNormal converges best, but the solution is still way off when I look at the predictive distribution.

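import numpy as np
import torch
import pyro
import pyro.distributions as dist
import wandb
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoNormal, AutoLowRankMultivariateNormal

pyro.clear_param_store()
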
def model(PatientID, Weeks, FVC_obs=None):
    μ_α = pyro.sample("μ_α", dist.Normal(0.0, 100.0))
    σ_α = pyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = pyro.sample("μ_β", dist.Normal(0.0, 100.0))
    σ_β = pyro.sample("σ_β", dist.HalfNormal(100.0))

    unique_patient_IDs = np.unique(PatientID)
    n_patients = len(unique_patient_IDs)

    with pyro.plate("plate_i", n_patients):
        α = pyro.sample("α", dist.Normal(μ_α, σ_α))
        β = pyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = pyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[PatientID] + β[PatientID] * Weeks

    with pyro.plate("data", len(PatientID)):
        pyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)


def guide(PatientID, Weeks, FVC_obs=None):
    global_μ_α = pyro.param("global_μ_α", torch.tensor(0.0))
    global_μ_β = pyro.param("global_μ_β", torch.tensor(0.0))
    global_σ_α = pyro.param("global_σ_α", torch.tensor(0.1), constraint=constraints.positive)
    global_σ_β = pyro.param("global_σ_β", torch.tensor(0.1), constraint=constraints.positive)

    μ_α = pyro.sample("μ_α", dist.Normal(global_μ_α, global_σ_α))
    μ_β = pyro.sample("μ_β", dist.Normal(global_μ_β, global_σ_β))
    σ_α = pyro.sample("σ_α", dist.HalfNormal(global_σ_α))
    σ_β = pyro.sample("σ_β", dist.HalfNormal(global_σ_β))
    
    global_σ = pyro.param("global_σ", torch.tensor(0.1), constraint=constraints.positive)
    σ = pyro.sample("σ", dist.HalfNormal(global_σ))

    n_patients = len(np.unique(PatientID))

    patient_mus = pyro.param("local_mus", torch.tensor(0.0).expand(2, n_patients))
    patient_scales = pyro.param("local_scales",
                                torch.tensor(0.1).expand(2, n_patients),
                                constraint=constraints.positive)

    with pyro.plate("plate_i", n_patients):
        pyro.sample("α", dist.Normal(patient_mus[0], patient_scales[0]))
        pyro.sample("β", dist.Normal(patient_mus[1], patient_scales[1]))

# Uncomment for different guides
# guide = AutoNormal(model)
# guide = AutoLowRankMultivariateNormal(model)

wandb.init(project="pyro-hierarchical-regression", entity="kenleejr")

wandb.config.learning_rate = 0.003
wandb.config.epochs = 2000

svi = pyro.infer.SVI(model=model,
                     guide=guide,
                     optim=pyro.optim.Adam({"lr": wandb.config.learning_rate}),
                     loss=pyro.infer.Trace_ELBO())

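for t in range(wandb.config.epochs):
    # svi.step returns the loss, i.e. a noisy estimate of the negative ELBO
    # (logged under the original "ELBO" key).
    loss = svi.step(PatientID, Weeks, FVC_obs)
    wandb.log({"ELBO": loss})

    if t % 100 == 0:
        # Parameters are registered lazily during the first svi.step call, so
        # read the param store here instead of listing its keys before the loop
        # (where it is still empty after clear_param_store).
        v_params = {name: value.detach().cpu().numpy()
                    for name, value in pyro.get_param_store().items()}
        wandb.log(v_params)
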

Here is a link to my notebook: https://github.com/kenleejr/variational_inference/blob/main/Pyro_Hierarchical_Regression.ipynb. Attached are images of the parameters during the optimization loop.



Hi @kenleejr92, what exactly is your question? Why do you expect SVI to perform as well as HMC?

With regard to your custom guide, I think you generally want more flexible distributions for the scale parameters, e.g. LogNormal. What you have is likely too inflexible:
σ = pyro.sample("σ", dist.HalfNormal(global_σ))

For general SVI tips and tricks, see the Pyro tutorial "SVI Part IV: Tips and Tricks".
For example, you might benefit from training longer and using a decaying learning rate (or, e.g., a learning rate that is decreased at discrete steps).
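
One way to set up a decaying learning rate is a constant multiplicative factor applied at every step, which is how the `lrd` option of `pyro.optim.ClippedAdam` works. The sketch below only does the arithmetic for choosing that factor. The starting rate and step count are taken from the post; the target final rate and the helper name `per_step_decay` are assumptions for illustration.

```python
def per_step_decay(initial_lr, final_lr, num_steps):
    """Per-step factor so that initial_lr * factor**num_steps == final_lr."""
    return (final_lr / initial_lr) ** (1.0 / num_steps)


initial_lr = 0.003   # learning rate from the post
final_lr = 0.00003   # assumed target: 100x smaller by the end of training
num_steps = 2000     # number of SVI steps from the post

lrd = per_step_decay(initial_lr, final_lr, num_steps)

# Simulate the schedule: multiply the learning rate by lrd once per step.
lr = initial_lr
for _ in range(num_steps):
    lr *= lrd

print(f"per-step decay: {lrd:.6f}, final learning rate: {lr:.2e}")
```

The resulting factor could then be passed as `pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd})` in place of the plain Adam optimizer.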

I don’t expect VI to perform as well as HMC, but I was hoping to get at least something viable. Even after training for a very long time (50k epochs) with lower or decaying learning rates, the resulting coefficients of the hierarchical regression are way off. I was looking at this paper, https://arxiv.org/pdf/2111.03144.pdf, which solves a couple of hierarchical regression problems with VI pretty well.

@kenleejr92 I don’t know what the dimensions of your problem are, but you might try AutoBNAFNormal or AutoDAIS.