Error loading and running a pickled MCMC object

I have a NumPyro model that consumes too much GPU memory to run MCMC for more than 500-1,000 samples at a time. I have been pickling the MCMC object as well as the samples every 500 steps, and then resuming sampling following the pattern shown in the NumPyro docs for mcmc.post_warmup_state. However, I recently ran into this issue when resuming from a pickled MCMC object:

with open('path_to_mcmc.pkl', 'rb') as f:
    mcmc = pickle.load(f)

mcmc.post_warmup_state = mcmc.last_state

mcmc.run(
    random.PRNGKey(0),
    *model_args,
    **model_kwargs,
)

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [23], line 1
----> 1 mcmc.run(
      2         random.PRNGKey(0),
      3         *model_args,
      4         **model_kwargs,
      5     )

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:593, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    591 map_args = (rng_key, init_state, init_params)
    592 if self.num_chains == 1:
--> 593     states_flat, last_state = partial_map_fn(map_args)
    594     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    595 else:

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:404, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    398 collection_size = self._collection_params["collection_size"]
    399 collection_size = (
    400     collection_size
    401     if collection_size is None
    402     else collection_size // self.thinning
    403 )
--> 404 collect_vals = fori_collect(
    405     lower_idx,
    406     upper_idx,
    407     sample_fn,
    408     init_val,
    409     transform=_collect_fn(collect_fields),
    410     progbar=self.progress_bar,
    411     return_last_val=True,
    412     thinning=self.thinning,
    413     collection_size=collection_size,
    414     progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    415     diagnostics_fn=diagnostics,
    416     num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    417 )
    418 states, last_val = collect_vals
    419 # Get first argument of type `HMCState`

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/util.py:358, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    356 with tqdm.trange(upper) as t:
    357     for i in t:
--> 358         vals = jit(_body_fn)(i, vals)
    359         t.set_description(progbar_desc(i), refresh=False)
    360         if diagnostics_fn:

    [... skipping hidden 14 frame]

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/util.py:323, in fori_collect.<locals>._body_fn(i, vals)
    320 @cached_by(fori_collect, body_fun, transform)
    321 def _body_fn(i, vals):
    322     val, collection, start_idx, thinning = vals
--> 323     val = body_fun(val)
    324     idx = (i - start_idx) // thinning
    325     collection = cond(
    326         idx >= 0,
    327         collection,
   (...)
    330         identity,
    331     )

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:172, in _sample_fn_nojit_args(state, sampler, args, kwargs)
    170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    171     # state is a tuple of size 1 - containing HMCState
--> 172     return (sampler.sample(state[0], args, kwargs),)

File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/hmc.py:771, in HMC.sample(self, state, model_args, model_kwargs)
    761 def sample(self, state, model_args, model_kwargs):
    762     """
    763     Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting
    764     :data:`~numpyro.infer.hmc.HMCState`.
   (...)
    769     :return: Next `state` after running HMC.
    770     """
--> 771     return self._sample_fn(state, model_args, model_kwargs)

TypeError: 'NoneType' object is not callable

Why is the sample function not being recognized here? When I inspect the NUTS kernel, it shows None for the _sample_fn attribute, even though the MCMC object still has a sampler attribute (<numpyro.infer.hmc.NUTS at 0x7f355ed6c940>).

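Currently, when the kernel is pickled, we delete its _sample_fn because it is not picklable, so it is None after unpickling. I guess we can add some logic here to take care of that, i.e. re-initialize the sampler when _sample_fn is missing while keeping the provided initial state. Something like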

if init_state is None or getattr(self.sampler, '_sample_fn', None):
    new_init_state = ...
    init_state = new_init_state if init_state is None else init_state

Could you try if it works?

This is fixed in pull request #1558 on pyro-ppl/numpyro (GitHub), "Allow pickled mcmc object to run post warmup phase", by fehiepsi.