@fehiepsi, I just tried to initialize with
mcmc_stocks.post_warmup_state = mcmc_stocks.last_state
and then got the following strange error. Is there a way around this, or did I just lose my samples? I used 10 warmup steps and 40 samples with 4 parallel chains.
I really do need this functionality because obtaining 100 samples can take hours, and sometimes the kernel suddenly resets after hours of running, which is very frustrating. I’m trying to get to 1000 post-warmup samples, so I’m thinking about running 50 samples repeatedly in a for-loop and storing the intermediate values using jax.device_get(mcmc_stocks.post_warmup_state)
as you suggested.
---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_81238/642589350.py in <module>
1 start = time.time()
----> 2 mcmc_stocks.run(mcmc_stocks.post_warmup_state.rng_key, data_bl, data_t, K, ind_mat,
3 ts, num_params, constants, aux_predictors_mat,
4 age_per_month, prior)
5 print("NUTS took ", round((time.time()-start)/60, 2), " minutes.") # or mcmc.run(random.PRNGKey(1))
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
576 states, last_state = _laxmap(partial_map_fn, map_args)
577 elif self.chain_method == "parallel":
--> 578 states, last_state = pmap(partial_map_fn)(map_args)
579 else:
580 assert self.chain_method == "vectorized"
[... skipping hidden 13 frame]
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
381 else collection_size // self.thinning
382 )
--> 383 collect_vals = fori_collect(
384 lower_idx,
385 upper_idx,
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
337 progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
338 _body_fn_pbar = progress_bar_fori_loop(_body_fn)
--> 339 last_val, collection, _, _ = fori_loop(
340 0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
341 )
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/util.py in fori_loop(lower, upper, body_fun, init_val)
137 return val
138 else:
--> 139 return lax.fori_loop(lower, upper, body_fun, init_val)
140
141
[... skipping hidden 15 frame]
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/util.py in wrapper_progress_bar(i, vals)
244
245 def wrapper_progress_bar(i, vals):
--> 246 result = func(i, vals)
247 _update_progress_bar(i + 1)
248 return result
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/util.py in _body_fn(i, vals)
318 def _body_fn(i, vals):
319 val, collection, start_idx, thinning = vals
--> 320 val = body_fun(val)
321 idx = (i - start_idx) // thinning
322 collection = cond(
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
172 def _sample_fn_nojit_args(state, sampler, args, kwargs):
173 # state is a tuple of size 1 - containing HMCState
--> 174 return (sampler.sample(state[0], args, kwargs),)
175
176
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
758 :return: Next `state` after running HMC.
759 """
--> 760 return self._sample_fn(state, model_args, model_kwargs)
761
762 def __getstate__(self):
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
468 )
469 # not update adapt_state after warmup phase
--> 470 adapt_state = cond(
471 hmc_state.i < wa_steps,
472 (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
117 return false_fun(false_operand)
118 else:
--> 119 return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
120
121
[... skipping hidden 16 frame]
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py in <lambda>(args)
471 hmc_state.i < wa_steps,
472 (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
--> 473 lambda args: wa_update(*args),
474 hmc_state.adapt_state,
475 identity,
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc_util.py in update_fn(t, accept_prob, z_info, state)
679 )
680
--> 681 t_at_window_end = t == adaptation_schedule[window_idx, 1]
682 window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
683 state = HMCAdaptState(
[... skipping hidden 1 frame]
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
5642 arr = asarray(arr)
5643 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5644 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
5645 unique_indices, mode, fill_value)
5646
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
5669 # We avoid generating a gather when indexer.gather_indices.size is empty.
5670 if not core.is_empty_shape(indexer.gather_indices.shape):
-> 5671 y = lax.gather(
5672 y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape,
5673 unique_indices=unique_indices or indexer.unique_indices,
[... skipping hidden 4 frame]
~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/interpreters/partial_eval.py in _assert_live(self)
1170 def _assert_live(self) -> None:
1171 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1172 raise core.escaped_tracer_error(self, None)
1173
1174 class JaxprStackFrame:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (1, 2) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _single_chain_mcmc at /Users/jeroenuleman/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:357 traced for pmap.
------------------------------
The leaked intermediate value was created on line /Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:544 (warmup_adapter).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:360 (_single_chain_mcmc)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py:735 (init)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py:716 (<lambda>)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc.py:300 (init_kernel)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:544 (warmup_adapter)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError