Store intermediate samples

Hi everyone, I’m dealing with very slow NUTS/MCMC sampling of a large system of differential algebraic equations. Even if I run just one chain, the progress bar often initially estimates that the process will take 2 hours, but later on (e.g., at around 15%) the estimate jumps to 13 hours, so I don’t really trust the progress bar anymore.

I’m wondering what I can do to make the process more feasible because I’m not always able to let the notebook run indefinitely.

Right now I’m using JupyterLab on macOS. Is that recommended for long runs like this, or would it be better to run a Python script from the terminal, or use Spyder?

I was also wondering if it’s possible to store intermediate samples during the sampling process (or after, but then I would run for fewer samples), so that I can interrupt the process and start again later?

Finally, is this unreliability of the progress bar a known issue, or is something going wrong on my side?

the issue isn’t the progress bar. it’s the fact that NUTS is adaptive. at any given point the progress bar makes a computation along the lines of “i’m at iteration 100; on average i’ve used 128 gradient steps per iteration; consequently my ETA is XXX”. however, during adaptation (as the mass matrix and step size change) and during inference (as different regions of the posterior are explored) the number of gradient steps per iteration will change. consequently you can’t expect a perfect prediction of the completion time, especially before you’ve completed warm-up.
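To make this concrete, here is a minimal sketch of a naive ETA of the kind described above, using made-up numbers (1000 iterations in total, 0.5 s per gradient step, 16 steps per iteration early in warmup and 255 steps later on). It only illustrates the reasoning; it is not the actual estimator used by NumPyro's progress bar.

```python
def naive_eta_seconds(iters_done, iters_total, grad_steps_so_far, secs_per_grad_step):
    """ETA that assumes every remaining iteration costs the running average so far."""
    avg_steps_per_iter = grad_steps_so_far / iters_done
    return (iters_total - iters_done) * avg_steps_per_iter * secs_per_grad_step


# Hypothetical numbers: early in warmup, trajectories are short (16 steps each).
early = naive_eta_seconds(100, 1000, 100 * 16, 0.5)

# 50 iterations later the step size has shrunk and each iteration needs 255 steps,
# which drags the running average (and therefore the ETA) upwards.
later = naive_eta_seconds(150, 1000, 100 * 16 + 50 * 255, 0.5)

print(f"early ETA: {early / 3600:.1f} h, later ETA: {later / 3600:.1f} h")
```

The estimate jumps from 2 hours to about 11 hours without anything going wrong: only the cost per iteration changed.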

Thanks for the explanation!

Do you have any idea whether I should put more trust in the optimistic (early) or pessimistic (later) result?

And would it be possible to store intermediate results or somehow access the samples from a chain that I was running but is taking too long to complete? It’s really frustrating when I’ve run 200 samples but the other 800 are taking too long to complete. Seems to me that I could still use the 200 samples and mix them with another 800 that are sampled later on? (Assuming in both cases the chains were sufficiently warmed up?)

in my experience the completion estimate that you get after warm-up finishes is pretty accurate.

because of the delicate nature of the adaptation algorithms i think it’s probably a bit tricky to do start-stop-restart-like workflows, but other forum posts or @fehiepsi may be able to point you towards possible solutions.

Yeah, you can search for post_warmup_state to find the run-stop-continue workflow.

Thanks! That’s very helpful.

I got a bit confused about this part of the documentation:
`mcmc.post_warmup_state = mcmc.last_state`

Is `post_warmup_state` the first state after the warm-up samples, or simply the last state that was sampled? So if I do 100 warmup iterations and 100 samples (200 total), does `post_warmup_state` give me the state at iteration 100 or at iteration 200?

And is this state simply equal to the set of parameter values sampled at that particular point? In that case, is initializing with `post_warmup_state` the same as storing the last sampled parameter values and then initializing the next run with them (i.e., using `init_to_value()`)?

Alternatively, if `post_warmup_state` contains more information than that, is there a convenient way of storing this state on my hard drive so that I can load and use it later?

Thanks again!

post_warmup_state is a state at some point after the warmup phase; it could be the first one or the 1000th one. For example, if you set mcmc.post_warmup_state = mcmc.last_state, it will be the last state of your chain, and the next run will start from that state. It contains more information than just the samples, such as the step size, mass matrix, gradient, potential energy, etc. You can use jax.device_get(state) to convert the device arrays to NumPy arrays and store them.


Thanks so much! I’ll try to work with that; I think storing the warmed-up state can save me a lot of time.

@martinjankowiak, just to quickly come back to your explanation about the progress bar: what would be a reasonable speed to actually expect from NUTS?

Does it evaluate my likelihood function once for each parameter (for autodiff, I guess?) in each iteration? I’m surprised that it’s so slow, because evaluating my likelihood function should only take a few seconds. So even if it were evaluated once per parameter (say that takes a minute, or even two minutes), 2 hours for 30 samples seems too long.

Also, wouldn’t that mean that, when dealing with a large number of parameters, it’s actually more efficient to use a gradient-free method like Metropolis, even if you need many more samples in total?

the progress bar reports things like “XXX steps of size YYY”

this tells you how many gradient steps per iteration it’s doing. for hard problems this will tend to max out at 2^max_tree_depth - 1 steps. max_tree_depth defaults to 10, i.e. 1023 steps. so a hard problem may require millions of gradient evaluations to get a few thousand samples.

no, gradient-free methods will almost always be much less efficient, except maybe in a small handful of dimensions
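On the earlier question of whether the likelihood is evaluated once per parameter: reverse-mode autodiff (what `jax.grad` uses) returns the whole gradient from one forward and one backward pass, whose cost is a small multiple of a single evaluation and does not grow with the number of parameters. A minimal sketch with a hypothetical stand-in log-density; for an ODE/DAE solver the constant factor can be larger, but it is still not one solve per parameter.

```python
import jax
import jax.numpy as jnp


def log_density(theta):
    # Hypothetical stand-in for an expensive log-likelihood:
    # an isotropic standard normal in len(theta) dimensions.
    return -0.5 * jnp.sum(theta ** 2)


theta = jnp.linspace(-1.0, 1.0, 1000)  # 1000 parameters

# A single reverse-mode call returns the value and all 1000 partial derivatives;
# log_density is not re-evaluated once per parameter (unlike finite differences).
value, grad = jax.value_and_grad(log_density)(theta)
```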


Thanks that’s helpful.

I just had a run that took 109.28 minutes (1:49:15 h).

It reports:
50/50 131.11s/it, 255 steps of size 3.02e-02.

So 131.11s/it is clearly the number of seconds per ‘sample’, and if I understand correctly it takes 255 gradient steps in this time. Does that mean it evaluated my likelihood function 255*50 = 12,750 times? That would mean each evaluation takes about half a second, which sounds reasonable to me.
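The same back-of-the-envelope arithmetic, written out. It assumes that every one of the 50 iterations used 255 steps, which the progress bar does not guarantee, since the step count it displays may only reflect the most recent iteration.

```python
secs_per_iter = 131.11  # from the progress bar
num_iters = 50          # "50/50"
steps_per_iter = 255    # step count shown on the progress bar

total_secs = secs_per_iter * num_iters
total_grad_evals = steps_per_iter * num_iters
secs_per_grad_eval = total_secs / total_grad_evals

print(f"{total_secs / 60:.1f} min, {total_grad_evals} gradient evaluations, "
      f"{secs_per_grad_eval:.2f} s each")
# 109.3 min, 12750 gradient evaluations, 0.51 s each
```

The total of about 109.3 minutes agrees with the reported 109.28 minutes.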

yes that’s right

@fehiepsi, I just tried to initialize with
mcmc_stocks.post_warmup_state = mcmc_stocks.last_state

and then got the following strange error. Is there a way around this, or did I just lose my samples? I used 10 warmup iterations and 40 samples with 4 parallel chains.

I really do need this functionality, because obtaining 100 samples can take hours and sometimes the kernel suddenly restarts after hours of running, which is very frustrating. I’m trying to get to 1000 post-warmup samples, so I’m thinking about running 50 samples at a time in a for-loop and storing the intermediate values using jax.device_get(mcmc_stocks.post_warmup_state), as you suggested.

UnexpectedTracerError                     Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_81238/ in <module>
      1 start = time.time()
----> 2, data_bl, data_t, K, ind_mat,
      3                             ts, num_params, constants, aux_predictors_mat,
      4                             age_per_month, prior)
      5 print("NUTS took ", round((time.time()-start)/60, 2), " minutes.")  # or

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    576                 states, last_state = _laxmap(partial_map_fn, map_args)
    577             elif self.chain_method == "parallel":
--> 578                 states, last_state = pmap(partial_map_fn)(map_args)
    579             else:
    580                 assert self.chain_method == "vectorized"

    [... skipping hidden 13 frame]

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    381             else collection_size // self.thinning
    382         )
--> 383         collect_vals = fori_collect(
    384             lower_idx,
    385             upper_idx,

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/ in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    337         progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
    338         _body_fn_pbar = progress_bar_fori_loop(_body_fn)
--> 339         last_val, collection, _, _ = fori_loop(
    340             0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
    341         )

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/ in fori_loop(lower, upper, body_fun, init_val)
    137         return val
    138     else:
--> 139         return lax.fori_loop(lower, upper, body_fun, init_val)

    [... skipping hidden 15 frame]

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/ in wrapper_progress_bar(i, vals)
    245         def wrapper_progress_bar(i, vals):
--> 246             result = func(i, vals)
    247             _update_progress_bar(i + 1)
    248             return result

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/ in _body_fn(i, vals)
    318     def _body_fn(i, vals):
    319         val, collection, start_idx, thinning = vals
--> 320         val = body_fun(val)
    321         idx = (i - start_idx) // thinning
    322         collection = cond(

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in _sample_fn_nojit_args(state, sampler, args, kwargs)
    172 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    173     # state is a tuple of size 1 - containing HMCState
--> 174     return (sampler.sample(state[0], args, kwargs),)

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in sample(self, state, model_args, model_kwargs)
    758         :return: Next `state` after running HMC.
    759         """
--> 760         return self._sample_fn(state, model_args, model_kwargs)
    762     def __getstate__(self):

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in sample_kernel(hmc_state, model_args, model_kwargs)
    468         )
    469         # not update adapt_state after warmup phase
--> 470         adapt_state = cond(
    471             hmc_state.i < wa_steps,
    472             (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/ in cond(pred, true_operand, true_fun, false_operand, false_fun)
    117             return false_fun(false_operand)
    118     else:
--> 119         return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)

    [... skipping hidden 16 frame]

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in <lambda>(args)
    471             hmc_state.i < wa_steps,
    472             (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
--> 473             lambda args: wa_update(*args),
    474             hmc_state.adapt_state,
    475             identity,

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in update_fn(t, accept_prob, z_info, state)
    679             )
--> 681         t_at_window_end = t == adaptation_schedule[window_idx, 1]
    682         window_idx = jnp.where(t_at_window_end, window_idx + 1, window_idx)
    683         state = HMCAdaptState(

    [... skipping hidden 1 frame]

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/_src/numpy/ in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   5642   arr = asarray(arr)
   5643   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5644   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   5645                  unique_indices, mode, fill_value)

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/_src/numpy/ in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   5669   # We avoid generating a gather when indexer.gather_indices.size is empty.
   5670   if not core.is_empty_shape(indexer.gather_indices.shape):
-> 5671     y = lax.gather(
   5672       y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape,
   5673       unique_indices=unique_indices or indexer.unique_indices,

    [... skipping hidden 4 frame]

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/jax/interpreters/ in _assert_live(self)
   1170   def _assert_live(self) -> None:
   1171     if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1172       raise core.escaped_tracer_error(self, None)
   1174 class JaxprStackFrame:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (1, 2) and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _single_chain_mcmc at /Users/jeroenuleman/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ traced for pmap.
The leaked intermediate value was created on line /Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (warmup_adapter). 
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (_single_chain_mcmc)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (init)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (<lambda>)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (init_kernel)
/Users/.../Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ (warmup_adapter)

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.

Also, @fehiepsi, the following gives an error as well.

first_state_array = jax.device_get(mcmc_stocks.post_warmup_state), data_bl, data_t, K, ind_mat, ts, num_params, constants, aux_predictors_mat, age_per_month, prior)


AttributeError                            Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_81238/ in <module>
      1 start = time.time()
----> 2, data_bl, data_t, K, ind_mat,   #mcmc_stocks.post_warmup_state.rng_key
      3                             ts, num_params, constants, aux_predictors_mat,
      4                             age_per_month, prior)
      5 print("NUTS took ", round((time.time()-start)/60, 2), " minutes.")  # or

~/Anaconda/anaconda3/envs/spyder-env/lib/python3.9/site-packages/numpyro/infer/ in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    540         self._kwargs = kwargs
    541         init_state = self._get_cached_init_state(rng_key, args, kwargs)
--> 542         if self.num_chains > 1 and rng_key.ndim == 1:
    543             rng_key = random.split(rng_key, self.num_chains)

AttributeError: 'HMCState' object has no attribute 'ndim'

How can I convert first_state_array back into something that works?

Thank you!

I’m not sure what causes the issue. Could you make a simple reproducible Colab notebook for this? Also, please make sure to pass a random key as the first argument of the run method.


Ah yes it works with the random key.
Regarding the first point, I also got it to work by switching the chain_method from ‘parallel’ to ‘vectorized’. I’m now able to sample repeatedly in a for-loop and store the intermediate results, even with 4 chains! :slight_smile: And it’s quite a bit faster than when I try to sample without interruption.

One more question about that. Is the warmup phase different from the regular sampling phase, or can I just repeatedly draw 50 samples from each of 4 chains (say 40 times, i.e. 2000 samples per chain) with 0 warmup samples, and then manually discard the first half of the samples at the end (i.e., the first 1000 samples of each chain)?

This seems practical because running 1000 warmup iterations before storing the result seems impossible for me: I haven’t been able to draw more than about 200 samples in a row.

The only problem seems to be that mcmc.print_summary() then returns absurd r_hat and n_eff values, I think because it includes all samples, including the ones that I would consider ‘warmup’ samples and discard.

Glad that the vectorized method works (I guess the parallel method might also work if you use device_get and properly block until the results are ready; see “Asynchronous dispatch” in the JAX documentation). During warmup, the step size and mass matrix are adapted, so if you don’t need the adaptation scheme, the warmup and sampling phases should be the same. Regarding the printed summary, it should print the stats of the last run. If that’s not what you expected, you can use the various utilities in numpyro.diagnostics.

(Btw, it would be always easier to talk about code behavior with some reproducible code.)


Thanks a lot @fehiepsi, I really appreciate your input.
Although I got it to work, the sampler is still very slow, and I have yet to reach a result with acceptable effective sample sizes and R-hat statistics.

Not sure if I should open a new topic about this, but I’ve been reading a bit about bad posterior geometries and wonder to what extent the poor sampling could be caused by strong correlations between the parameters.

I have been able to do some MAP optimizations and noticed some very high correlations between parameters (many above 0.5 and a few even higher than 0.9). I tried reducing the dimension of the posterior, even by fixing most of the parameters to their MAP values and sampling only a few, but the sampling was still very inefficient.

Do you have any thoughts on this? Potentially I could use sensitivity analyses to fix less relevant parameters to MAP values and only sample from the most important ones, or something like that.
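A minimal sketch for checking which parameter pairs are strongly correlated in a set of draws, assuming the samples have been flattened to one column per scalar parameter. The helper name, the 0.9 threshold and the synthetic data are hypothetical. If strong linear correlations are confirmed, NumPyro's `NUTS(model, dense_mass=True)` adapts a full (dense) mass matrix during warmup instead of a diagonal one, which is intended for correlated posteriors; whether it helps for this particular DAE model is untested here.

```python
import numpy as np


def strongly_correlated_pairs(samples, names, threshold=0.9):
    """Return (name_i, name_j, corr) for parameter pairs with |corr| above threshold."""
    # samples has shape (num_draws, num_params): one column per scalar parameter.
    corr = np.corrcoef(samples, rowvar=False)
    pairs = []
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            if abs(corr[i, j]) > threshold:
                pairs.append((names[i], names[j], float(corr[i, j])))
    return pairs


# Synthetic example: b is almost a copy of a, while c is independent of both.
rng = np.random.default_rng(0)
a = rng.normal(size=5000)
b = a + 0.1 * rng.normal(size=5000)
c = rng.normal(size=5000)
draws = np.column_stack([a, b, c])

pairs = strongly_correlated_pairs(draws, ["a", "b", "c"])
print(pairs)
```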