Store intermediate samples

Hi everyone, I’m dealing with very slow NUTS/MCMC sampling of a large system of differential algebraic equations. Even if I run just one chain, the progress bar often starts out telling me the process will take 2 hours, but the estimate later grows to 13 hours (e.g., at 15% or so), so I don’t really trust the progress bar anymore.

I’m wondering what I can do to make the process more feasible because I’m not always able to let the notebook run indefinitely.

Right now I’m using JupyterLab. Is that recommended for long runs like this, or would it be better to run from a Python script in a terminal, or from Spyder or the like? (macOS)

I was also wondering if it’s possible to store intermediate samples during the sampling process (or after, but then I would run for fewer samples), so that I can interrupt the process and start again later?

Finally, is this unreliability of the progress bar a known issue, or is something going wrong on my side?

the issue isn’t the progress bar. it’s the fact that NUTS is adaptive. at any given point the progress bar makes a computation along the lines of “i’m at iteration 100; on average i’ve used 128 gradient steps per iteration; consequently my ETA is XXX”. however, during adaptation (as the mass matrix and step size change) and during inference (as different regions of the posterior are explored) the number of gradient steps per iteration will change. consequently you can’t expect a perfect prediction of completion time, especially before you’ve completed warm-up.
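as a toy illustration of that extrapolation (all numbers hypothetical):

```python
# naive ETA arithmetic of the kind a progress bar does (illustrative numbers)
iterations_done = 100
total_iterations = 1000
avg_gradient_steps = 128   # average so far; this moves around as NUTS adapts
seconds_per_step = 0.5     # one likelihood + gradient evaluation

eta_hours = (total_iterations - iterations_done) * avg_gradient_steps * seconds_per_step / 3600
print(f"ETA: {eta_hours:.1f} h")  # jumps whenever avg_gradient_steps jumps
```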

Thanks for the explanation!

Do you have any idea whether I should put more trust in the optimistic (early) or pessimistic (later) result?

And would it be possible to store intermediate results or somehow access the samples from a chain that I was running but is taking too long to complete? It’s really frustrating when I’ve run 200 samples but the other 800 are taking too long to complete. Seems to me that I could still use the 200 samples and mix them with another 800 that are sampled later on? (Assuming in both cases the chains were sufficiently warmed up?)

in my experience the completion estimate that you get after warm-up finishes is pretty accurate.

because of the delicate nature of the adaptation algorithms i think it’s probably a bit tricky to do start-stop-restart-like workflows but other forum posts or @fehiepsi may be able to point you towards possible solutions

Yeah, you can search for post_warmup_state for the run-stop-continue workflow.

Thanks! That’s very helpful.

I got a bit confused about this part of the documentation:
mcmc.post_warmup_state = mcmc.last_state

The post_warmup_state is the first state after the warm-up samples, right? Or is it simply the last state that was sampled? So if I do 100 warmup iterations and 100 samples (200 total), does post_warmup_state then give me the state at 100 iterations, or at 200?

And is this state simply equal to the set of parameters sampled at that particular point? In that case, is initializing with post_warmup_state the same as simply storing the last sampled parameter values and then initializing the next run with those (i.e., using init_to_value())?

Alternatively, if post_warmup_state contains more information than that, is there a convenient way of storing this state to my hard drive so that I can load and use it later?

Thanks again!

post_warmup_state is a state at some point after the warmup phase; it could be the first one or the 1000th one. E.g., if you set mcmc.post_warmup_state = mcmc.last_state, it will be the last state of your chain. The chain will then start at that post-warmup state. It contains more information than just samples: things like step_size, the mass matrix, grad, potential energy, … You can use jax.device_get(state) to convert the device arrays to numpy arrays and store it.
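Something like this should work (a minimal sketch with a toy model; the file name and sample counts are arbitrary):

```python
import pickle

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model():
    # toy model just to make the sketch self-contained
    numpyro.sample("x", dist.Normal(0.0, 1.0))


mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(jax.random.PRNGKey(0))
first_batch = mcmc.get_samples()

# pull the state (positions, step size, mass matrix, rng key, ...) off the
# device as numpy arrays so the namedtuple can be pickled
state = jax.device_get(mcmc.last_state)
with open("mcmc_state.pkl", "wb") as f:
    pickle.dump(state, f)

# later, possibly in a fresh session: reload and continue without re-warmup
with open("mcmc_state.pkl", "rb") as f:
    state = pickle.load(f)
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.post_warmup_state = state
mcmc.run(mcmc.post_warmup_state.rng_key)  # starts from the stored state
second_batch = mcmc.get_samples()
```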


Thanks so much! I’ll try to work with that; I think storing the warmed-up state can save me a lot of time.

@martinjankowiak, just to come back quickly to your explanation about the progress bar: what would be a reasonable speed to actually expect from NUTS?

Does it evaluate my likelihood function once for each parameter (for autodiff, I guess?) in each iteration? I’m surprised that it’s so slow, because evaluating my likelihood function should only take a few seconds. So even if it were evaluated once per parameter (say that takes a minute, or even two), 2 hours for 30 samples seems too long.

Also, wouldn’t that mean that when dealing with a large number of parameters it’s actually more efficient to use a gradient-free method like Metropolis, even if you do need many more samples in total?

the progress bar reports things like “XXX steps of size YYY”

this tells you how many gradient steps per iteration it’s doing. for hard problems this will tend to max out at 2^{max_tree_depth}, and max_tree_depth defaults to 10. so a hard problem may require millions of gradient evaluations to get a few thousand samples.

no, gradient-free methods will almost always be much less efficient, except maybe in a small handful of dimensions
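to make that budget concrete (illustrative numbers; max_tree_depth is a NUTS constructor argument in numpyro, and lowering it trades exploration per iteration for speed):

```python
# back-of-the-envelope gradient budget when trees saturate at max_tree_depth
max_tree_depth = 10                          # numpyro's default for NUTS
grad_evals_per_sample = 2 ** max_tree_depth  # up to 1024 per iteration
num_samples = 1000
print(grad_evals_per_sample * num_samples)   # ~1e6 gradient evaluations

# a cheaper (but less thorough) kernel caps the tree at depth 8:
# kernel = NUTS(model, max_tree_depth=8)
```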


Thanks that’s helpful.

I just had a run that took 109.28 minutes (1:49:15 h).

It reports:
50/50 131.11s/it, 255 steps of size 3.02e-02.

So 131.11s/it is clearly the number of seconds per ‘sample’, and if I understand correctly it takes 255 gradient steps in that time. Does that mean it evaluated my likelihood function 255 × 50 = 12,750 times? That would mean each evaluation takes about half a second. Sounds reasonable to me.
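Quick sanity check of that arithmetic (assuming ~255 steps on every iteration, which is what the bar suggests):

```python
# total likelihood/gradient evaluations implied by the progress bar readout
steps_per_sample = 255
num_samples = 50
total_evals = steps_per_sample * num_samples  # 12,750
run_seconds = 109.28 * 60                     # the 1:49:15 run
print(run_seconds / total_evals)              # ~0.51 s per evaluation
```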

yes that’s right