Initialize each chain of MCMC separately

Hi!

I am running NUTS in a setting where the data grows over time. What I would like to do is initialize each new chain with the last sample from the corresponding previous chain. This works fine with a single chain: I extract the last sample and pass init_strategy = numpyro.infer.util.init_to_value(values=lastsample) to my NUTS kernel.

This approach does not work for multiple chains, as util.init_to_value does not seem to support initializing different chains with different sets of parameters. Is there a way to make this work that I am unaware of, or should I submit it as a feature request on the GitHub page?

I tried to show the issue in a Google Colab below. The final two cells summarize the issue and the unsatisfactory workaround of initializing all chains to the parameters of a single chain.

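I think post_warmup_state is what you need. Setting mcmc.post_warmup_state = mcmc.last_state makes the next run, with the new data, skip the warmup phase and continue from where the previous run stopped. Because last_state holds one state per chain, each chain resumes from its own last sample. If you need to trigger the warmup phase again in the next run, I guess you can do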

import jax.numpy as jnp
mcmc.post_warmup_state = mcmc.last_state._replace(i=jnp.zeros_like(mcmc.last_state.i))

Thanks! This looks like what I am looking for.