Hi - I’m contributing to a library for Bayesian analysis and visualization (arviz), and we are trying to add a function that converts Pyro objects into data that can be used there.
As part of that, I am working with the eight schools MCMC example, but cannot figure out how to save the output, specifically the pyro.infer.mcmc.mcmc.MCMC object. I know pyro.get_param_store() is recommended elsewhere, but it is empty after running NUTS. Using pickle/dill/cloudpickle also fails for inscrutable reasons.
The param store is only used by SVI, where the result of inference has a succinct representation in terms of the parameters of the variational distribution. MCMC does not interact with the param store; it returns a TracePosterior object, which is simply a holder of samples from the posterior distribution over the latent variables. It is, however, surprising that this object is not picklable. What error do you get when you try to pickle it? It would be great if you could open an issue about the pickling failure, and I will take a look at it.
Thanks for the quick response! This project might be helpful to pyro in providing diagnostics and plots. The currently open pull request allows this sort of API.
Thanks for creating the issue, @colcarroll! It will be great to have a trace_plot utility for Pyro, so thanks for contributing this upstream to arviz! Also, let us know if anything in the interface should be reconsidered, e.g. other functionality or information that the TracePosterior class should expose for diagnostics and plotting.
Once an object is converted to netcdf, it can use any function from arviz (it is planned that in a month or two, all plotting and diagnostics in PyMC3 will use arviz). It is probably a little heavily weighted towards HMC right now, but there is interest in getting more support for variational inference on the roadmap.
Here, for example, is a comparison of 4 chains from pystan, 4 from pymc3, and 1 from pyro, each with 500 draws.
This piece of code using dill does run, but when jit_compile=True it throws a “TypeError: can’t pickle torch._C.Function objects”. Is it possible to save the object in this case?
Below is an example: the code runs without jit_compile=True but raises a TypeError with jit_compile=True.
import torch
import pyro.distributions as dist
import pyro
from pyro.infer.mcmc import MCMC, NUTS
import dill
assert pyro.__version__.startswith('1.1.0')
# Generate data
true_coefs = torch.tensor([1., 2., 3.])
data = torch.randn(2000, 3)
dim = 3
labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
def model(data):
    coefs_mean = torch.zeros(dim)
    coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
    y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
    return y
# Run NUTS
kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)
# Save the posterior
with open('savemcmc.pkl', 'wb') as f:
    dill.dump(mcmc, f)
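One possible workaround, if only the posterior draws are needed, is to save the samples rather than the whole MCMC object. The TypeError seems to come from the JIT-compiled functions held by the kernel, so the sketch below pickles only a plain dict of draws. The helper names (to_plain, save_samples, load_samples) are made up for illustration, and the commented usage line assumes a Pyro version where mcmc.get_samples() returns a dict mapping site names to tensors (Pyro 1.x). I have not checked this against every release, so treat it as a starting point rather than the recommended approach.

```python
import pickle


def to_plain(samples):
    """Turn {site_name: tensor} into {site_name: nested list}.

    Any value exposing .tolist() (e.g. torch.Tensor, numpy.ndarray) is
    converted; other values are passed through unchanged.
    """
    return {
        name: value.tolist() if hasattr(value, "tolist") else value
        for name, value in samples.items()
    }


def save_samples(samples, path):
    """Pickle only the posterior draws, not the sampler, kernel, or model."""
    with open(path, "wb") as f:
        pickle.dump(to_plain(samples), f)


def load_samples(path):
    """Load the dict of draws written by save_samples."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Intended use after mcmc.run(data) above (assumption: Pyro 1.x API):
#     save_samples(mcmc.get_samples(), "samples.pkl")
#     draws = load_samples("samples.pkl")   # {"beta": [[...], ...]}
```

Converting to nested lists drops dtype and device information, so this trades fidelity for a file that loads without torch or Pyro installed. Sampler state and diagnostics stored on the MCMC object are not saved this way.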