Sequential HMC Sampling

Hi all,

I have a situation in which the HMC chains sometimes get stuck: for example, one of the eight parallel chains stays at the same value for the entire run. I was hoping to rerun burn-in and update just the bad chain with information from one of the new chains (I will then run from this point multiple times to generate more samples sequentially).

From my understanding, I would need to change the mcmc.last_state attribute, which would then be used to update mcmc.post_warmup_state at the next run. Are these the only attributes I need to update for later sequential runs to work correctly? Also, are all subattributes used later in the simulation, or do I only need to update a few to add in the good chain (i.e., are both inverse_mass_matrix and mass_matrix_sqrt_inv used)?

Alternatively, if I could extract the specific seeds used to generate the good chains, I could run burn-in again with those values and everything should be set up properly. Is that correct/possible?

Thanks in advance!

Great question! Assuming you want to replace the bad chain 0 with the state (except for the random seeds) of the good chain 1, you can do

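import jax

mcmc.run(...)  # initial run, including warmup
last_state = mcmc.last_state
# Overwrite every field of chain 0 (bad) with the value from chain 1 (good)
last_state_new = jax.tree_util.tree_map(lambda x: x.at[0].set(x[1]), last_state)
mcmc.post_warmup_state = last_state_new
mcmc.run(new_key)  # skips warmup and starts from last_state_new; new_key seeds the run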

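The state already stores the adaptation information, such as the mass matrix and step size, so copying the whole state carries all of it over; you don't need to pick individual subattributes.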
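To see what the `tree_map` trick does to a nested state, here is a small self-contained sketch. `ToyState` and `ToyAdaptState` are hypothetical stand-ins for the sampler state (they are not NumPyro classes); the only assumption is that every leaf carries a leading axis of length `num_chains`, as in a multi-chain run.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyAdaptState(NamedTuple):
    # Hypothetical stand-in for adaptation fields (step size, mass matrix)
    step_size: jnp.ndarray
    inverse_mass_matrix: jnp.ndarray


class ToyState(NamedTuple):
    # Hypothetical stand-in for a batched sampler state; every leaf has a
    # leading axis of length num_chains
    z: jnp.ndarray
    adapt_state: ToyAdaptState


def copy_chain(state, bad, good):
    """Overwrite chain `bad` with chain `good` in every leaf of `state`."""
    return jax.tree_util.tree_map(lambda x: x.at[bad].set(x[good]), state)


num_chains, dim = 4, 3
state = ToyState(
    z=jnp.arange(num_chains * dim, dtype=jnp.float32).reshape(num_chains, dim),
    adapt_state=ToyAdaptState(
        # chain 0 has a collapsed step size, mimicking a stuck chain
        step_size=jnp.array([0.0, 0.5, 0.4, 0.3]),
        # row i is filled with the value i
        inverse_mass_matrix=jnp.ones((num_chains, dim)) * jnp.arange(num_chains)[:, None],
    ),
)

new_state = copy_chain(state, bad=0, good=1)
print(new_state.adapt_state.step_size)
```

Because `tree_map` walks every leaf, nested fields such as the step size and mass matrix are copied along with the position, which is why no field-by-field bookkeeping is needed.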

Regarding running burn-in once again: yes, you are right. The initial seeds for the chains are:

from jax import random
rng_keys = random.split(your_rng_key_provided_in_run_method, num_chains)
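Since the per-chain seeds come from a deterministic split, a good chain's initial key can be recovered later from the original run key and the chain index. A minimal sketch, assuming the split described in the answer; `run_key = random.PRNGKey(0)` is a hypothetical stand-in for whatever key you passed to `mcmc.run`.

```python
import jax.numpy as jnp
from jax import random

num_chains = 8
run_key = random.PRNGKey(0)  # hypothetical: the key originally passed to mcmc.run

# Per-chain initial keys, split as described in the answer
chain_keys = random.split(run_key, num_chains)

# Splitting is deterministic: the same run key always yields the same chain
# keys, so a good chain's initial key can be recovered by index afterwards
good_chain = 1
recovered = random.split(run_key, num_chains)[good_chain]
assert jnp.array_equal(recovered, chain_keys[good_chain])

# Every chain gets a distinct key
assert len({tuple(k.tolist()) for k in chain_keys}) == num_chains
```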

You can replace the initial seed for the bad chain, then rerun:

# Give chain 0 a fresh seed (12345 is arbitrary); the other chains keep theirs
rng_keys = rng_keys.at[0].set(random.PRNGKey(12345))
mcmc.run(rng_keys)

Great, thank you! Will try this and let you know if there are any issues.