Save MCMC results

Hi - I’m contributing to a library for Bayesian analysis and visualization (arviz), and we are trying to add a function that converts Pyro objects into data that can be used there.

As part of that, I am working with the eight schools MCMC example, but cannot figure out how to save the output, specifically the pyro.infer.mcmc.mcmc.MCMC object. I know pyro.get_param_store() is recommended elsewhere, but it is empty after running NUTS. Using pickle/dill/cloudpickle also fails for inscrutable reasons.

I would appreciate any suggestions!

The param store is only used by SVI, where the result of inference has a succinct representation in terms of the parameters of the variational distribution. MCMC does not interact with the param store; it returns a TracePosterior object, which is simply a holder of samples from the posterior distribution over the latent variables. It is, however, surprising that this object is not picklable. What error do you get when you try to pickle it? It would be great if you could open an issue about the pickling problem, and I will take a look at it.
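
Because the result of MCMC is just a set of draws, one option that sidesteps pickling the sampler object is to save the draws themselves. The sketch below assumes the draws are available as a dict of tensors (recent Pyro releases expose this via `mcmc.get_samples()`). The `samples` dict here is a made-up stand-in, and an in-memory buffer replaces a real file path.

```python
import io

import torch

# Hypothetical stand-in for the dict of posterior draws from a sampler,
# e.g. 500 draws of a 3-dimensional latent variable named "beta".
samples = {"beta": torch.randn(500, 3)}

# torch.save serializes the dict of tensors; a file path works the same way.
buffer = io.BytesIO()
torch.save(samples, buffer)

# Rewind the buffer and load the draws back.
buffer.seek(0)
restored = torch.load(buffer)
```

Only the draws are kept this way, so anything else stored on the sampler object would be lost.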

Thanks for the quick response! This project might be helpful to Pyro in providing diagnostics and plots. The currently open pull request allows this sort of API.

Thanks for creating the issue, @colcarroll! It would be great to have a trace_plot utility for Pyro, so thanks for contributing this upstream to arviz! Also, let us know if there is anything in the interface that should be reconsidered, e.g. other functionality or information that the TracePosterior class should expose for diagnostics and plotting.

Once an object is converted to netcdf, it can use any function from arviz (it is planned that in a month or two, all plotting and diagnostics in PyMC3 will use arviz). It is probably a little heavily weighted towards HMC right now, but there is interest in getting more support for variational inference on the roadmap.

Here, for example, is a comparison of 4 chains from pystan, 4 from pymc3, and 1 from pyro, each with 500 draws.

Hello,
Just to mention that dill can save and load pyro.infer.mcmc.mcmc.MCMC objects.

import dill

# supervised_posterior is the MCMC object obtained after running inference
with open('file.pkl', 'wb') as f:
    dill.dump(supervised_posterior, f)
with open('file.pkl', 'rb') as f:
    supervised_posterior = dill.load(f)

This dill snippet does run, but when jit_compile=True it throws a “TypeError: can’t pickle torch._C.Function objects”. Is it possible to save the object in this case?

Below is an example: the code runs without jit_compile=True but raises a TypeError with jit_compile=True.

import torch
import pyro.distributions as dist
import pyro
from pyro.infer.mcmc import MCMC, NUTS
import dill

assert pyro.__version__.startswith('1.1.0')

# Generating data
true_coefs = torch.tensor([1., 2., 3.])
data = torch.randn(2000, 3)
dim = 3
labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()

def model(data):
    coefs_mean = torch.zeros(dim)
    coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
    y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
    return y

# Running NUTS
kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)

# Saving the posterior
with open('savemcmc.pkl', 'wb') as f:
    dill.dump(mcmc, f)

@vincentbt I used mcmc.sampler = None before saving and it worked for my purpose. Would it be enough for you?

@fehiepsi Adding mcmc.sampler = None does not solve the problem; I still get the same error. Could you show a code example? Thank you.

Sorry, it seems that you need to set mcmc.kernel.potential_fn = None instead of mcmc.sampler = None. I just tested it with your code.
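
This workaround amounts to clearing the one attribute that cannot be serialized before pickling. A minimal sketch of that pattern using only the standard library is shown here. The `SimpleNamespace` objects are hypothetical stand-ins, not Pyro classes, and a `threading.Lock` plays the role of the unpicklable JIT-compiled function.

```python
import copy
import pickle
import threading
from types import SimpleNamespace

# Hypothetical stand-ins, not Pyro objects: the lock mimics an attribute
# (such as a JIT-compiled potential function) that pickle cannot serialize.
kernel = SimpleNamespace(potential_fn=threading.Lock())
mcmc = SimpleNamespace(kernel=kernel, samples={"beta": [0.9, 2.1, 3.0]})

# Pickling the object as-is fails with a TypeError.
try:
    pickle.dumps(mcmc)
    direct_pickle_ok = True
except TypeError:
    direct_pickle_ok = False

# Clear the offending attribute on shallow copies so the live object
# keeps its function, then pickle the copy.
to_save = copy.copy(mcmc)
to_save.kernel = copy.copy(mcmc.kernel)
to_save.kernel.potential_fn = None

restored = pickle.loads(pickle.dumps(to_save))
```

Working on a copy is a design choice made for this sketch: assigning None directly on the original, as in the post, is shorter but leaves the in-memory object without that attribute.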

Thank you, it works perfectly.

@fehiepsi What is the solution for numpyro? mcmc.sampler = None allowed me to save one of my models, but it failed for another model.

@vincentbt maybe you can try mcmc._cache = {}? It works in arviz.

It worked, thanks!