Checkpoints for `mcmc.run()`

Hey, is it possible to save the `mcmc.run()` state and samples periodically while sampling with NumPyro? I understand that `post_warmup_state` can be used to save the state at the end of warmup, but I am looking for a way around the walltime limits on shared platforms such as university clusters. It would be amazing if the full chain information were saved every so often and then restored later, either because more samples are needed or because the walltime was reached and the job was cancelled.

Currently, if the job is killed before all the requested samples have been drawn, my understanding is that everything is lost.

Many thanks!

Davide

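`post_warmup_state` can also be used for checkpointing: set it to `mcmc.last_state` and call `mcmc.run()` again to continue sampling without repeating warmup. See the example in its docs.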


Thank you for the reply! I am actually looking for something that checkpoints the current state during sampling, i.e. a state that can be restored and continued even if the job was cancelled before the end of the warmup phase, which is very long in my problem. Is something like this available?

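How about using the lower-level API `fori_collect`? It runs the sampler for a given number of iterations, so you can call it in chunks and save the state in between.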