Hi!
I am running NUTS in a setting where the data grows over time. What I would like to do is initialize the new chains with the last sample from each of the previous chains. This works fine when I only have one chain, since I can extract the last sample and use init_strategy = numpyro.infer.util.init_to_value(values=lastsample) in my NUTS kernel.
This approach does not work for multiple chains, because util.init_to_value does not seem to support initializing different chains with different sets of parameters. Is there a way to get this to work that I am not aware of? Or should I submit this as a feature request on the GitHub page?
I tried to show the issue in a Google Colab notebook below. The final two cells summarize the issue and the unsatisfactory workaround of initializing all chains to the parameters of a single chain.