Simple MCMC on Gamma/Beta providing wrong estimates

Hi everyone,

I just started with probabilistic programming (and Pyro) and am trying to implement a very simple generative model: a Beta likelihood whose two parameters each have a Gamma prior:

$$\alpha \sim \mathrm{Gamma}(1, 1)$$
$$\beta \sim \mathrm{Gamma}(1, 1)$$
$$x \sim \mathrm{Beta}(\alpha, \beta)$$

When I run MCMC to perform inference on it, I am not able to retrieve good estimates for $\alpha$ and $\beta$, and I am pretty sure I am messing something up : )
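For reference, reading Gamma(1, 1) as shape 1 and rate 1 (matching `concentration` and `rate` in the code below), each prior density is $e^{-\alpha}$ on $\alpha > 0$ (and likewise for $\beta$), so the unnormalized log posterior that the sampler has to explore is:

$$\log p(\alpha, \beta \mid x_{1:n}) = -\alpha - \beta + (\alpha - 1)\sum_{i=1}^{n}\log x_i + (\beta - 1)\sum_{i=1}^{n}\log(1 - x_i) - n\log B(\alpha, \beta) + \text{const}, \qquad B(\alpha, \beta) = \frac{\Gamma(\alpha)\,\Gamma(\beta)}{\Gamma(\alpha + \beta)}$$

With $n = 1000$ observations the likelihood terms outweigh the two prior terms, so the posterior should concentrate near the data-generating values.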

Here is the MWE for my case:

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import torch
import pyro
import pyro.distributions as dist
import pyro.infer.mcmc as mcmc
from pyro.infer.abstract_infer import EmpiricalMarginal


a = 5
b = 1
n_samples = 1000
beta_data = ss.beta(a=a, b=b).rvs(n_samples).astype(np.float32)
print(beta_data.shape)
plt.hist(beta_data, bins=100)
plt.show()

def beta_model(data):
    alpha_0 = torch.ones(1)
    beta_0 = torch.ones(1)
    alpha_prior = pyro.sample('alpha', dist.Gamma(concentration=alpha_0, rate=beta_0))
    beta_prior = pyro.sample('beta', dist.Gamma(concentration=alpha_0, rate=beta_0))
    x = pyro.sample('x', dist.Beta(concentration1=alpha_prior, concentration0=beta_prior), obs=data)
    return x

# hmc_kernel = mcmc.HMC(beta_model, step_size=0.0855, num_steps=4)
nuts_kernel = mcmc.NUTS(beta_model, adapt_step_size=True)
mcmc_run = mcmc.MCMC(nuts_kernel, num_samples=1000, warmup_steps=10000).run(torch.from_numpy(beta_data))
beta_posterior = EmpiricalMarginal(mcmc_run, 'beta')
print('beta', beta_posterior.mean) 
alpha_posterior = EmpiricalMarginal(mcmc_run, 'alpha')
print('alpha', alpha_posterior.mean) 
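As a sanity check that does not involve Pyro at all, a method-of-moments fit shows roughly which values 1000 draws from Beta(5, 1) support. This is a swapped-in estimator, not what MCMC computes, and `beta_method_of_moments` is a hypothetical helper name; it should land near, but not exactly at, 5 and 1.

```python
import numpy as np


def beta_method_of_moments(x):
    """Method-of-moments (alpha, beta) estimate for data assumed to follow a Beta distribution."""
    x = np.asarray(x, dtype=np.float64)
    m = x.mean()
    v = x.var()
    # For Beta(a, b): mean = a / (a + b) and var = mean * (1 - mean) / (a + b + 1),
    # so a + b = mean * (1 - mean) / var - 1.
    total = m * (1.0 - m) / v - 1.0
    return m * total, (1.0 - m) * total


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.beta(5.0, 1.0, size=1000)  # same generating values as the MWE
    print(beta_method_of_moments(data))
```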

Instead of recovering the true values of 5 and 1, I am getting something like:

beta tensor([ 0.5838]) 
alpha tensor([ 2.5927]) 

What am I doing wrong?

The acceptance probability is 0, so we are not generating even a single uncorrelated sample from the posterior. You can switch on the debugging flag (see below) to see if there are any problems in sampling.

This particular example seems to be quite sensitive to step size. It works fine for HMC (code below), but it seems to have problems with NUTS, or with step size adaptation. cc @fehiepsi. We will need to take a deeper look to see what’s happening here. Thanks for reporting! Could you create an issue on GitHub and reference this post?

import logging

import torch

import pyro
import pyro.distributions as dist
import pyro.infer.mcmc as mcmc
from pyro.infer.abstract_infer import EmpiricalMarginal

logging.basicConfig(format='%(message)s', level=logging.INFO)


a = 5
b = 1
n_samples = 1000
beta_data = dist.Beta(a, b).sample(torch.Size((n_samples,)))
print(beta_data.shape)


def beta_model(data):
    alpha_prior = pyro.sample('alpha', dist.Gamma(1., 1.))
    beta_prior = pyro.sample('beta', dist.Gamma(1., 1.))
    x = pyro.sample('x', dist.Beta(alpha_prior, beta_prior), obs=data)
    return x


kernel = mcmc.HMC(beta_model, step_size=0.001, num_steps=8)
mcmc_run = mcmc.MCMC(kernel, num_samples=1000, warmup_steps=1000).run(beta_data)
beta_posterior = EmpiricalMarginal(mcmc_run, 'beta')
print('beta', beta_posterior.mean)
alpha_posterior = EmpiricalMarginal(mcmc_run, 'alpha')
print('alpha', alpha_posterior.mean)
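To illustrate why the step size matters so much here, below is a minimal NumPy-only HMC sketch for the same model. It is not Pyro's implementation: it samples $(\log\alpha, \log\beta)$ (similar in spirit to how Pyro moves positive-valued sites to an unconstrained space) with an identity mass matrix, i.e. momenta from N(0, I). The function names, the starting point $\alpha = \beta = 1$, and the crude divergence guard are assumptions made for illustration. Since digamma is not in the Python standard library, a small series approximation is included.

```python
import math

import numpy as np


def digamma(x):
    """Digamma for x > 0: shift with psi(x) = psi(x + 1) - 1/x, then use the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
    return result


def make_log_density(x):
    """Log posterior of q = (log alpha, log beta) for Gamma(1, 1) priors and a Beta likelihood."""
    n = len(x)
    s1 = float(np.sum(np.log(x)))     # sum of log x_i
    s0 = float(np.sum(np.log1p(-x)))  # sum of log(1 - x_i)

    def log_density_and_grad(q):
        u, v = q
        a, b = math.exp(u), math.exp(v)
        log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
        # Priors contribute -a - b; the log-Jacobian of a = exp(u), b = exp(v) contributes u + v.
        logp = -a - b + (a - 1.0) * s1 + (b - 1.0) * s0 - n * log_beta_fn + u + v
        d_ab = digamma(a + b)
        da = -1.0 + s1 - n * (digamma(a) - d_ab)
        db = -1.0 + s0 - n * (digamma(b) - d_ab)
        # Chain rule through exp(), plus the derivative of the Jacobian term.
        return logp, np.array([a * da + 1.0, b * db + 1.0])

    return log_density_and_grad


def hmc(log_density_and_grad, q0, step_size, num_steps, num_samples, seed=0):
    """Plain HMC with an identity mass matrix; returns (alpha, beta) samples and acceptance rate."""
    rng = np.random.default_rng(seed)
    q = np.array(q0, dtype=float)
    logp, grad = log_density_and_grad(q)
    samples = np.empty((num_samples, q.size))
    accepted = 0
    for i in range(num_samples):
        p = rng.standard_normal(q.size)  # momentum ~ N(0, I)
        h_old = -logp + 0.5 * p @ p
        q_new, p_new, logp_new, grad_new = q.copy(), p.copy(), logp, grad
        ok = True
        for _ in range(num_steps):  # leapfrog: half kick, drift, half kick
            p_new = p_new + 0.5 * step_size * grad_new
            q_new = q_new + step_size * p_new
            if not np.all(np.abs(q_new) < 30.0):  # crude divergence guard (assumption)
                ok = False
                break
            logp_new, grad_new = log_density_and_grad(q_new)
            p_new = p_new + 0.5 * step_size * grad_new
        if ok:
            h_new = -logp_new + 0.5 * p_new @ p_new
            # Metropolis correction: accept with probability min(1, exp(h_old - h_new)).
            if np.isfinite(h_new) and rng.uniform() < math.exp(min(0.0, h_old - h_new)):
                q, logp, grad = q_new, logp_new, grad_new
                accepted += 1
        samples[i] = np.exp(q)  # store (alpha, beta) on the original scale
    return samples, accepted / num_samples


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    data = rng.beta(5.0, 1.0, size=1000)  # same generating values as the thread
    target = make_log_density(data)
    start = np.log([1.0, 1.0])  # alpha = beta = 1 (assumed starting point)
    for step_size in (0.001, 0.01, 0.1):
        samples, rate = hmc(target, start, step_size, num_steps=8, num_samples=1000)
        print(f"step_size={step_size}: acceptance={rate:.2f}, "
              f"mean (alpha, beta) over last 500 = {samples[-500:].mean(axis=0)}")
```

Under these assumptions, a step size of 0.001 should accept almost every proposal but move very slowly. A step of 0.1 is large compared with the posterior width in log space (a few hundredths with 1000 observations, by a rough Fisher-information estimate), so the leapfrog integrator becomes unstable and essentially every proposal should be rejected, which is the same zero-acceptance symptom described above.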

Thanks Neeraj,

I totally missed how to inspect and debug sampling!
I do not see why this particular example is so sensitive to step size. As a sanity check, it works fine with default parameters in the PyMC3 implementations of HMC and NUTS.

I will open an issue on GitHub, as you suggest.

Cheers

> As a sanity check, it works fine with default parameters in the PyMC3 implementation of HMC and NUTS.

PyMC3 does mass matrix adaptation; we have a contributor who is currently working on this for Pyro (see GitHub issue pyro-ppl/pyro #1137, "Mass Matrix Adaptation for HMC and NUTS"). Currently, we just sample momenta from N(0, I). It would be interesting to see what the results look like in PyMC3 without mass matrix adaptation. If it suffers from similar issues, that would point to one source of the discrepancy you are seeing.
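The effect of a mass matrix can be seen on a toy problem. The sketch below is not Pyro's or PyMC3's adaptation code: it runs plain HMC on a hypothetical two-dimensional Gaussian whose coordinates differ in scale by a factor of 100, and the "adapted" inverse mass matrix is simply set to the target's true variances, standing in for what a warmup-based estimate would produce. With momenta from N(0, I), one step size has to suit the narrowest direction; rescaling the momenta lets a single step size fit both.

```python
import math

import numpy as np

# Hypothetical target: independent Gaussian whose two coordinates differ in scale by 100x.
SIGMA = np.array([1.0, 0.01])


def potential(q):
    """U(q) = -log N(q; 0, diag(SIGMA**2)) up to an additive constant."""
    return 0.5 * np.sum((q / SIGMA) ** 2)


def grad_potential(q):
    return q / SIGMA**2


def hmc_acceptance(inv_mass, step_size, num_steps=10, num_iters=500, seed=0):
    """Acceptance rate of HMC with momentum p ~ N(0, M), where M = diag(1 / inv_mass)."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal(2) * SIGMA  # start at a draw from the target
    accepted = 0
    for _ in range(num_iters):
        p = rng.standard_normal(2) / np.sqrt(inv_mass)
        h_old = potential(q) + 0.5 * np.sum(inv_mass * p**2)
        q_new, p_new = q.copy(), p.copy()
        for _ in range(num_steps):  # leapfrog with velocity dq/dt = M^{-1} p
            p_new -= 0.5 * step_size * grad_potential(q_new)
            q_new += step_size * inv_mass * p_new
            p_new -= 0.5 * step_size * grad_potential(q_new)
        h_new = potential(q_new) + 0.5 * np.sum(inv_mass * p_new**2)
        if np.isfinite(h_new) and rng.uniform() < math.exp(min(0.0, h_old - h_new)):
            q = q_new
            accepted += 1
    return accepted / num_iters


if __name__ == "__main__":
    identity = np.ones(2)
    adapted = SIGMA**2  # stands in for variances estimated during warmup
    print("identity, step 0.5:  ", hmc_acceptance(identity, 0.5))
    print("identity, step 0.005:", hmc_acceptance(identity, 0.005))
    print("adapted,  step 0.5:  ", hmc_acceptance(adapted, 0.5))
```

With an identity mass matrix, a step of 0.5 makes the leapfrog integrator unstable in the narrow coordinate, so proposals should be rejected; shrinking the step to 0.005 restores acceptance but moves the wide coordinate only on the order of 0.05 per iteration. The rescaled version keeps the larger step and should accept most proposals.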
