Continue sampling from previous (saved) state

I am running MCMC with the NUTS sampler, but I cannot keep all of the samples in memory. Instead, I want to collect samples in an in-memory buffer, write them to disk, and then continue sampling from the last state.

Is this possible?

Hi @ritchie46, this is easier to do with the dev version of Pyro, in the following way:

import tqdm
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.util import initialize_model

# Simulated logistic-regression data: a feature matrix `data` and binary
# `labels` generated from the ground-truth weights `true_coefs` = (1, ..., dim).
dim = 3
data = torch.randn(2000, dim)
true_coefs = torch.arange(1., dim + 1.)
true_logits = (true_coefs * data).sum(-1)
labels = dist.Bernoulli(logits=true_logits).sample()

def model(data, obs=None):
    """Bayesian logistic regression model.

    Places an independent standard-normal prior on the coefficient vector
    ``beta`` and observes Bernoulli labels whose logits are the dot product of
    each row of ``data`` with ``beta``.

    Args:
        data: 2-D tensor of features; rows are observations, columns are
            features.
        obs: labels to condition the ``y`` site on. When omitted, the
            module-level ``labels`` tensor is used, so ``model(data)`` keeps
            its original behaviour.

    Returns:
        The value of the ``y`` sample site (the observed labels).
    """
    if obs is None:
        obs = labels
    num_data, num_features = data.shape[0], data.shape[-1]
    # Take the feature count from `data` itself (not the global `dim`) and
    # match its dtype/device, so the model works for any design matrix.
    coefs_prior = dist.Normal(data.new_zeros(num_features),
                              data.new_ones(num_features)).to_event(1)
    coefs = pyro.sample('beta', coefs_prior)
    # Observations are conditionally independent given `beta`; declare this
    # with a plate instead of relying on an undeclared batch dimension.
    with pyro.plate('data', num_data):
        y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=obs)
    return y

# Number of warmup (adaptation) iterations to discard and of posterior draws
# to keep; previously these were repeated as the literals 1000 and 2000.
WARMUP_STEPS = 1000
NUM_SAMPLES = 1000

# Build a potential function over unconstrained values for `model`, so the
# NUTS kernel can be stepped manually in the loop below.
init_params, potential_fn, transforms, _ = initialize_model(
    model, (data,), {}, jit_compile=True, skip_jit_warnings=True)
hmc_kernel = NUTS(model=None, potential_fn=potential_fn, transforms=transforms)
hmc_kernel.initial_params = init_params
hmc_kernel.setup(warmup_steps=WARMUP_STEPS)  # initialize warmup adapter

collection = []
z = init_params  # current sampler state, in unconstrained space
for i in tqdm.trange(WARMUP_STEPS + NUM_SAMPLES):
    # Feed the previous state back in so each step continues from the last
    # one; `z` stays unconstrained because it is never modified below.
    z = hmc_kernel.sample(z)
    if i < WARMUP_STEPS:
        continue  # discard warmup draws
    # Transform unconstrained values to constrained values into a *new* dict,
    # leaving the sampler's state `z` untouched.
    constrained_z = {
        name: transforms[name].inv(value) if name in transforms else value
        for name, value in z.items()
    }
    # `collection` is an in-memory buffer: to bound memory, periodically
    # write it to disk (e.g. with `torch.save`) and clear it.
    collection.append(constrained_z)