I have a NumPyro model that consumes too much GPU memory to run MCMC for more than 500-1,000 samples at a time. I have been pickling the `mcmc` object as well as the samples every 500 steps, and then resuming sampling by following the pattern shown in the NumPyro docs for `mcmc.post_warmup_state`. However, I recently ran into this issue:
# Reload the MCMC object that was pickled at the end of the previous batch.
# NOTE: pickle.load can execute arbitrary code; only load checkpoint files
# that you wrote yourself.
with open('path_to_mcmc.pkl', 'rb') as f:
    mcmc = pickle.load(f)

# Continue sampling from the final state of the previous batch.
mcmc.post_warmup_state = mcmc.last_state

# This is the call that raises the TypeError shown in the traceback.
mcmc.run(
    random.PRNGKey(0),
    *model_args,
    **model_kwargs,
)
I get this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [23], line 1
----> 1 mcmc.run(
2 random.PRNGKey(0),
3 *model_args,
4 **model_kwargs,
5 )
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:593, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
591 map_args = (rng_key, init_state, init_params)
592 if self.num_chains == 1:
--> 593 states_flat, last_state = partial_map_fn(map_args)
594 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
595 else:
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:404, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
398 collection_size = self._collection_params["collection_size"]
399 collection_size = (
400 collection_size
401 if collection_size is None
402 else collection_size // self.thinning
403 )
--> 404 collect_vals = fori_collect(
405 lower_idx,
406 upper_idx,
407 sample_fn,
408 init_val,
409 transform=_collect_fn(collect_fields),
410 progbar=self.progress_bar,
411 return_last_val=True,
412 thinning=self.thinning,
413 collection_size=collection_size,
414 progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
415 diagnostics_fn=diagnostics,
416 num_chains=self.num_chains if self.chain_method == "parallel" else 1,
417 )
418 states, last_val = collect_vals
419 # Get first argument of type `HMCState`
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/util.py:358, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
356 with tqdm.trange(upper) as t:
357 for i in t:
--> 358 vals = jit(_body_fn)(i, vals)
359 t.set_description(progbar_desc(i), refresh=False)
360 if diagnostics_fn:
[... skipping hidden 14 frame]
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/util.py:323, in fori_collect.<locals>._body_fn(i, vals)
320 @cached_by(fori_collect, body_fun, transform)
321 def _body_fn(i, vals):
322 val, collection, start_idx, thinning = vals
--> 323 val = body_fun(val)
324 idx = (i - start_idx) // thinning
325 collection = cond(
326 idx >= 0,
327 collection,
(...)
330 identity,
331 )
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/mcmc.py:172, in _sample_fn_nojit_args(state, sampler, args, kwargs)
170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
171 # state is a tuple of size 1 - containing HMCState
--> 172 return (sampler.sample(state[0], args, kwargs),)
File /opt/miniconda3/envs/refit_fvs/lib/python3.10/site-packages/numpyro/infer/hmc.py:771, in HMC.sample(self, state, model_args, model_kwargs)
761 def sample(self, state, model_args, model_kwargs):
762 """
763 Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting
764 :data:`~numpyro.infer.hmc.HMCState`.
(...)
769 :return: Next `state` after running HMC.
770 """
--> 771 return self._sample_fn(state, model_args, model_kwargs)
TypeError: 'NoneType' object is not callable
Why is the sample function not being recognized here? When I inspect the NUTS kernel, it shows `None` for the `_sample_fn` attribute, even though there is a `sampler` attribute (`<numpyro.infer.hmc.NUTS at 0x7f355ed6c940>`).