Can't reproduce NUTS results with SVI

Hi everyone,

I am trying to build a better intuition for how to use SVI. For learning purposes, I simulated a small dataset with 3 different groups (categories) and am trying to recover the true parameter values with Pyro. I was able to recover the correct parameter values using MCMC, but I am having a lot of trouble doing the same inference with SVI. I suspect I am messing something up in the specification of the guide. I would be very happy to get some help with this. Here is the code:

import numpy as np
import pandas as pd
import pyro
import torch
import pyro.poutine as poutine
import arviz as az
from pyro import distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS, Predictive
from pyro.optim import Adam
from matplotlib import pyplot as plt

The data for each group are drawn from a lognormal distribution with the parameters shown below

g1 = [{'group': 0, 'target': np.random.lognormal(5., 1.)} for _ in range(100)]
g2 = [{'group': 1, 'target': np.random.lognormal(2.0, 1.)} for _ in range(100)]
g3 = [{'group': 2, 'target': np.random.lognormal(-1., 1.)} for _ in range(100)]

records = []
records.extend(g1)
records.extend(g2)
records.extend(g3)

d = pd.DataFrame.from_records(records)

Defining a simple varying intercepts model

def model(group_id, target_obs=None):
    
    n_groups = len(np.unique(group_id))    
    
    # Parameters of distributions
    loc_a = pyro.param('loc_a', torch.zeros(n_groups))
    scale_a = pyro.param('loc_scale', torch.tensor(1.))
    sigma_scale = pyro.param('sigma_scale', torch.tensor(2.0))
    
    # Distributions
    a = pyro.sample('a', dist.Normal(loc_a, scale_a))
    sigma = pyro.sample('sigma', dist.HalfNormal(sigma_scale))
    
    # Model for the mean
    mu = a[group_id]
    pyro.sample('obs', dist.LogNormal(mu, sigma), obs=target_obs)

Starting with MCMC

train_data = {
    'group_id': torch.tensor(d['group'].values),
    'target_obs': torch.tensor(d['target'].values)
}
m = MCMC(NUTS(model), num_samples=1000, warmup_steps=500, num_chains=1)
m.run(**train_data)

m.get_samples()['a'].mean(axis=0)

The posterior mean of the 'a' parameters is the following tensor, which recovers the true values as expected:

tensor([ 4.8418, 1.9719, -0.8970], dtype=torch.float64)

Trying to reproduce the same results with SVI. Defining the guide for approximate inference

def guide(group_id, target_obs=None):
    
    n_groups = len(np.unique(group_id))
    
    # Initializing all params
    loc_a = torch.zeros(n_groups)
    loc_scale = torch.tensor(1.)
    sigma_scale = torch.tensor(2.0)
    
    # Registering learnable params in the pyro param store
    loc_a_param = pyro.param('loc_a', loc_a)
    loc_scale_param = pyro.param('loc_scale', loc_scale, constraint=constraints.positive)
    sigma_scale_param = pyro.param('sigma_scale', sigma_scale, constraint=constraints.positive)
    
    # guide distributions
    pyro.sample('sigma', dist.HalfNormal(sigma_scale_param))
    pyro.sample('a', dist.Normal(loc_a_param, loc_scale_param))

Learning params with SVI

optim = Adam({'lr': 0.001})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, loss=elbo)

pyro.clear_param_store()
num_epochs = 2000
losses = []
for j in range(num_epochs):
    loss = svi.step(**train_data)
    losses.append(loss)
    if j % 100 == 0:
        print(loss)

Sampling from posterior and checking learned param values

p = Predictive(model, guide=guide, num_samples=800)
post_samples = p(group_id=torch.unique(train_data['group_id']))
post_samples['a'].mean(axis=0)

The output of post_samples['a'].mean(axis=0) is:

tensor([ 0.1140, 0.1112, -0.1202], grad_fn=<MeanBackward1>)

It seems that I am either missing some parameter configuration in the guide or specifying it wrong. I tried to build some intuition by looking at a few examples, but unfortunately I could not figure out why the SVI inference isn't working. It would be very nice to get some help here and figure out the reason 🙂

are you using the same model for SVI and MCMC? because note that param statements encode learnable parameters. SVI will move those around, MCMC will not.
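
e.g. schematically (an illustrative snippet, not your exact code):

# with a pyro.param the prior's location is itself learnable,
# so SVI will optimize it together with the guide's parameters
loc_a = pyro.param('loc_a', torch.zeros(3))
a = pyro.sample('a', dist.Normal(loc_a, 1.))

# with a plain tensor the prior is fixed, for both SVI and MCMC
a = pyro.sample('a', dist.Normal(torch.zeros(3), 1.))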

Yes, I am. Shouldn't that be negligible, since I am just trying to obtain the same result with both MCMC and SVI (I also tried running SVI after MCMC)?

well you’re allowing the prior to change in one case and not in the other. so i don’t know why you’d necessarily expect comparable results.

Then I think I don't get how to use SVI. I also tried building two models: one for MCMC without the pyro.param calls, and another for the SVI step with the pyro.param calls. Still, I couldn't get the desired results. Could you provide a reference or an example where this concept is shown?

I am relatively new to Pyro, coming from PyMC, so I apologize if my question is too basic.

i would suggest that for both MCMC and SVI the model should not contain any params:

    loc_a = torch.zeros(n_groups)

then the corresponding priors will be treated as fixed distributions, which is presumably what you want.
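
e.g. a minimal sketch along those lines (same structure as your model, just with plain tensors for the prior hyperparameters):

def model(group_id, target_obs=None):
    n_groups = len(np.unique(group_id))

    # fixed prior hyperparameters: plain tensors, not pyro.params,
    # so they stay put under both MCMC and SVI
    loc_a = torch.zeros(n_groups)
    scale_a = torch.tensor(1.)
    sigma_scale = torch.tensor(2.)

    a = pyro.sample('a', dist.Normal(loc_a, scale_a))
    sigma = pyro.sample('sigma', dist.HalfNormal(sigma_scale))

    mu = a[group_id]
    pyro.sample('obs', dist.LogNormal(mu, sigma), obs=target_obs)

the learnable variational parameters then live only in your guide. alternatively you can let pyro construct a guide for you automatically, e.g. pyro.infer.autoguide.AutoNormal(model).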

Ok, I will try that. Thank you very much!