Hi everyone,
I am trying to build a better intuition for how to use SVI. For learning purposes, I simulated a small dataset with 3 different groups (categories) and I am trying to recover the true parameter values with Pyro. I was able to recover the correct parameter values using MCMC, but I am having a lot of trouble performing inference with SVI. I suspect that I am getting something wrong in the specification of the guide, and I would be very happy to get some help with this. Here is the code:
import numpy as np
import pandas as pd
import pyro
import torch
import pyro.poutine as poutine
import arviz as az
from pyro import distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS, Predictive
from pyro.optim import Adam
from matplotlib import pyplot as plt
Each group's data come from a lognormal distribution with the parameters defined below:
g1 = [{'group': 0, 'target': np.random.lognormal(5., 1.)} for _ in range(100)]
g2 = [{'group': 1, 'target': np.random.lognormal(2.0, 1.)} for _ in range(100)]
g3 = [{'group': 2, 'target': np.random.lognormal(-1., 1.)} for _ in range(100)]
records = []
records.extend(g1)
records.extend(g2)
records.extend(g3)
d = pd.DataFrame.from_records(records)
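As a sanity check on data simulated this way, the maximum-likelihood estimate of a lognormal location parameter is simply the mean of log(target) within each group, so any well-fit model should land near these numbers. The sketch below regenerates data of the same shape with numpy's Generator API and a fixed seed; the seed and the vectorized sampling are my own assumptions, not part of the original code.

```python
import numpy as np

# Same setup as above: 3 groups, 100 lognormal draws each, sigma = 1.
# The fixed seed is an assumption, used only for reproducibility.
rng = np.random.default_rng(0)
true_mu = np.array([5.0, 2.0, -1.0])
group_id = np.repeat(np.arange(3), 100)          # 0,...,0, 1,...,1, 2,...,2
target = rng.lognormal(mean=true_mu[group_id], sigma=1.0)

# MLE of the lognormal location per group: mean of log(target) in that group.
log_y = np.log(target)
mle_mu = np.array([log_y[group_id == g].mean() for g in range(3)])
print(mle_mu)  # should be close to [5, 2, -1]
```

With 100 draws per group and sigma = 1, the standard error of each estimate is about 0.1, so these values are the targets SVI should be approaching.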
Defining a simple varying-intercepts model:
def model(group_id, target_obs=None):
    n_groups = len(np.unique(group_id))
    # Parameters of distributions
    loc_a = pyro.param('loc_a', torch.zeros(n_groups))
    scale_a = pyro.param('loc_scale', torch.tensor(1.))
    sigma_loc = pyro.param('sigma_scale', torch.tensor(2.0))
    # Distributions
    a = pyro.sample('a', dist.Normal(loc_a, scale_a))
    sigma = pyro.sample('sigma', dist.HalfNormal(sigma_loc))
    # Model for the mean: each observation uses its own group's intercept
    mu = a[group_id]
    pyro.sample('obs', dist.LogNormal(mu, sigma), obs=target_obs)
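For reference, the likelihood term of this model can be written out directly: each observation contributes a LogNormal(a[group_id], sigma) log-density, where indexing a by group_id picks that observation's group intercept. The numpy sketch below evaluates it; the helper name `loglik` and the toy data are hypothetical, for illustration only. It shows that the true intercepts score far higher than the all-zero values the model's 'loc_a' starts from.

```python
import numpy as np

def loglik(a, sigma, group_id, y):
    """Sum of LogNormal(a[group_id], sigma) log-densities (hypothetical helper)."""
    mu = a[group_id]                      # per-observation intercept, as in the model
    z = (np.log(y) - mu) / sigma
    return np.sum(-np.log(y) - np.log(sigma) - 0.5 * np.log(2 * np.pi) - 0.5 * z**2)

rng = np.random.default_rng(1)
true_a = np.array([5.0, 2.0, -1.0])
group_id = np.repeat(np.arange(3), 100)
y = rng.lognormal(mean=true_a[group_id], sigma=1.0)

ll_true = loglik(true_a, 1.0, group_id, y)       # near the maximum
ll_zero = loglik(np.zeros(3), 1.0, group_id, y)  # the initial loc_a values
print(ll_true, ll_zero)
```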
Starting with MCMC:
train_data = {
    'group_id': torch.tensor(d['group'].values),
    'target_obs': torch.tensor(d['target'].values)
}
# 1000 posterior samples after 500 warm-up steps, single chain
m = MCMC(NUTS(model), num_samples=1000, warmup_steps=500, num_chains=1)
m.run(**train_data)
m.get_samples()['a'].mean(axis=0)
The posterior means of the ‘a’ parameters are the following, close to the true values (5, 2, -1), as expected:
tensor([ 4.8418, 1.9719, -0.8970], dtype=torch.float64)
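Under MCMC, the pyro.param calls in the model should simply return their initial values (NUTS does not update them), so the effective priors are a ~ Normal(0, 1) and sigma ~ HalfNormal(2). As a rough check, assuming sigma is known and fixed at 1, each intercept then has a conjugate Normal posterior whose mean is sum(log y) / (n + 1), only slightly shrunk toward 0 when n = 100. The sketch below is a hypothetical worked example of that conjugate update, not what NUTS computes internally.

```python
import numpy as np

def normal_posterior(log_y, prior_mean=0.0, prior_sd=1.0, sigma=1.0):
    """Conjugate Normal-Normal update for one group intercept, sigma treated as known."""
    n = log_y.size
    precision = 1.0 / prior_sd**2 + n / sigma**2
    mean = (prior_mean / prior_sd**2 + log_y.sum() / sigma**2) / precision
    return mean, np.sqrt(1.0 / precision)

rng = np.random.default_rng(2)
log_y = np.log(rng.lognormal(mean=5.0, sigma=1.0, size=100))
post_mean, post_sd = normal_posterior(log_y)
# With prior Normal(0, 1) and sigma = 1: post_mean = (100 / 101) * sample mean
print(post_mean, log_y.mean(), post_sd)
```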
Next, I tried to reproduce the same results with SVI. Here is the guide I defined for approximate inference:
def guide(group_id, target_obs=None):
    n_groups = len(np.unique(group_id))
    # Initializing all params
    loc_a = torch.zeros(n_groups)
    loc_scale = torch.tensor(1.)
    sigma_scale = torch.tensor(2.0)
    # Registering learnable params in the pyro param store
    loc_a_param = pyro.param('loc_a', loc_a)
    loc_scale_param = pyro.param('loc_scale', loc_scale, constraint=constraints.positive)
    sigma_scale_param = pyro.param('sigma_scale', sigma_scale, constraint=constraints.positive)
    # guide distributions
    pyro.sample('sigma', dist.HalfNormal(sigma_scale_param))
    pyro.sample('a', dist.Normal(loc_a_param, loc_scale_param))
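One thing worth checking here: the model and the guide register their parameters under the same names ('loc_a', 'loc_scale', 'sigma_scale'), and pyro.param returns the existing store entry when a name is reused, so the prior and the guide end up being the same distributions with the same tensors. With Trace_ELBO, each latent site contributes log p(z) - log q(z); when p and q are identical that term is exactly zero, so the priors drop out and the loss reduces to the expected log-likelihood. A common pattern is to give the model fixed prior hyperparameters (no pyro.param) and keep learnable parameters only in the guide. The numpy sketch below illustrates the cancellation for the 'a' site; it is a stand-alone check, not Pyro code, and the Normal(0, 10) prior is only an example.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    return -np.log(scale) - 0.5 * np.log(2 * np.pi) - 0.5 * ((x - loc) / scale) ** 2

rng = np.random.default_rng(3)
loc_a, loc_scale = np.array([0.1, 0.1, -0.1]), 1.0   # shared "param store" values

# Draw a from the guide q(a) = Normal(loc_a, loc_scale)
a = loc_a + loc_scale * rng.standard_normal((1000, 3))

# Prior and guide share parameters -> log p(a) - log q(a) is exactly 0
shared = normal_logpdf(a, loc_a, loc_scale) - normal_logpdf(a, loc_a, loc_scale)

# Fixed prior (example only: Normal(0, 10)) -> the term no longer cancels;
# its expectation under q is -KL(q || p), about -1.81 here
fixed = normal_logpdf(a, 0.0, 10.0) - normal_logpdf(a, loc_a, loc_scale)

print(np.abs(shared).max(), fixed.mean())
```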
Learning the parameters with SVI:
optim = Adam({'lr': 0.001})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, loss=elbo)
pyro.clear_param_store()
num_epochs = 2000
losses = []
for j in range(num_epochs):
    loss = svi.step(**train_data)
    losses.append(loss)
    if j % 100 == 0:
        print(loss)
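The step budget may also matter. Adam's per-step update is roughly bounded by the learning rate, so with lr=0.001 and 2000 steps a parameter initialised at 0 can move only about 2 units, short of the true intercept of 5 for the first group; noisy ELBO gradients shrink the effective step further. The plain-numpy Adam below uses betas 0.9/0.999 and eps 1e-8, which I assume match the torch.optim.Adam defaults that pyro.optim.Adam wraps, and minimises the deterministic toy loss (theta - 5)^2 / 2 starting from 0. The toy loss and the larger learning rate of 0.05 are illustrative assumptions, not tuned values for this model.

```python
import numpy as np

def adam_1d(grad, theta0, lr, steps, b1=0.9, b2=0.999, eps=1e-8):
    """Textbook Adam (Kingma & Ba) on a single scalar parameter."""
    theta, m, v = theta0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(theta)
        m = b1 * m + (1 - b1) * g          # first-moment EMA
        v = b2 * v + (1 - b2) * g * g      # second-moment EMA
        m_hat = m / (1 - b1 ** t)          # bias corrections
        v_hat = v / (1 - b2 ** t)
        theta -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

def grad(theta):
    return theta - 5.0   # gradient of 0.5 * (theta - 5)^2

theta_slow = adam_1d(grad, 0.0, lr=0.001, steps=2000)  # cannot exceed 2.0
theta_fast = adam_1d(grad, 0.0, lr=0.05, steps=2000)   # reaches the optimum
print(theta_slow, theta_fast)
```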
Sampling from the posterior and checking the learned parameter values:
p = Predictive(model, guide=guide, num_samples=800)
post_samples = p(group_id=np.unique(train_data['group_id']))
post_samples['a'].mean(axis=0)
The output of post_samples['a'].mean(axis=0) is:
tensor([ 0.1140, 0.1112, -0.1202], grad_fn=<MeanBackward1>)
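When reading this output, note that Predictive draws 'a' from the guide, so with the guide above the sample mean over 800 draws is essentially an estimate of the learned 'loc_a' parameter, with Monte Carlo error of about loc_scale / sqrt(800), roughly 0.035 if loc_scale stayed near 1. Under that assumption, the printed values suggest 'loc_a' moved only about 0.1 away from its initial value of zero. The numpy sketch below illustrates the relationship; the loc_a and loc_scale values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(4)
loc_a = np.array([0.1, 0.1, -0.1])   # hypothetical learned guide means
loc_scale = 1.0                       # hypothetical learned guide scale
num_samples = 800

# Sampling site 'a' from the guide q(a) = Normal(loc_a, loc_scale)
a_samples = loc_a + loc_scale * rng.standard_normal((num_samples, 3))
mc_error = loc_scale / np.sqrt(num_samples)   # standard error of each mean

print(a_samples.mean(axis=0), mc_error)
```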
It seems that I am either missing some parameter configuration in the guide or specifying it incorrectly. I looked at several examples to build some intuition, but unfortunately I could not figure out why the SVI inference isn't working. I would really appreciate any help in finding the reason.