Initialize each chain of MCMC separately

Hi!

I am running NUTS in a setting where the data grow over time. What I would like to do is initialize the new chains with the last sample from each of the previous chains. This works fine when I only have one chain, as I can extract the last sample and use init_strategy = numpyro.infer.util.init_to_value(values=lastsample) in my NUTS kernel.

This approach does not work for multiple chains, as util.init_to_value does not seem to support initializing different chains with different sets of parameters. Is there any way to get this to work that I am unaware of? Or should this just be submitted as a feature request on the GitHub page?

I tried to show the issue in a Google Colab notebook below. The final two cells summarize the issue and the unsatisfactory workaround of initializing all chains to the parameters of one chain.

I think post_warmup_state is what you need. It will skip the warmup phase on the next run with the new data. If you need to trigger the warmup phase in the next run, I guess you can do

mcmc.post_warmup_state = mcmc.last_state._replace(i=jnp.array(0))

Thanks! This looks like what I am looking for.

Hello,

I’m trying to do sequential MCMC too. The motivation is similar to this thread and this other thread where I have sequential data coming in and would like to update from the last state, i.e., doing this.

I would appreciate

  1. being able to initialize the values for my multiple chains separately (and potentially initialize the inverse mass matrix separately as well in the future).
  2. being able to rerun the warmup steps when new data are coming in.

However, so far I haven’t been able to get it to work.

As in this thread, I can’t use init_to_value since it only allows for one value. There are several problems with using “mcmc.last_state._replace(i=np.zeros(num_chains)).”

  1. It looks like if I only want to initialize the values but not the inverse mass matrix, I’d need to take care of that myself and change all the other details in the last_state. I would deeply appreciate it if I could use something as simple as init_to_value.
  2. I’m assuming that the “warmup” phase is determined by “i”? So when I call “mcmc.run,” it will run for “num_samples” steps, but the first “num_warmup” of them would be warmup steps? I’m not sure my understanding is correct. However, it feels a bit confusing to code up, since I’d need to change “num_samples” from num_samples to num_warmup+num_samples. In addition, when I call “mcmc.get_samples,” I’d need to take the last num_samples samples myself.
  3. I suspect that my understanding about 2) is incorrect because when I look at the inverse_mass_matrix in mcmc.last_state, somehow it is not updated. In particular, this is the experiment I ran:
    mcmc = MCMC(kernel, num_warmup=1, num_samples=20, num_chains=1)
    mcmc.warmup(rng_key, O=O)
    mcmc.post_warmup_state=mcmc.last_state._replace(i=np.array(0))
    mcmc.run(rng_key, O=O)
    print(mcmc.last_state)
    The inverse_mass_matrix is still identity.
    (The above is just an experiment I was running. Eventually, the use case would be for the observed data to change over time, which would also affect the posterior distribution.)

In short, I would really appreciate it if anyone could give me guidance on

  1. how to set up initial values for multiple chains separately with warmup
  2. (for the future) how to change the inverse_mass_matrix properly with warmup

Thank you so much! I really deeply appreciate it.

Hi @janelai22, I think the functional API provides the flexibility you need:

  • First, use the initialize_model utility with dynamic_args=True (because you want to change the data later) to get model_info
  • Then, call hmc with potential_fn_gen=model_info.potential_fn to get init_kernel and sample_kernel
  • vmap init_kernel over init_params and inverse_mass_matrix to get the init states
  • vmap/pmap fori_collect(num_warmup, num_warmup + num_samples, partial(sample_kernel, model_args=model_args, model_kwargs=model_kwargs)) with the init_state from the previous step
  • get the last state, modify it and run fori_collect again

If you want to use the MCMC API, then we can expose init_state in the MCMC.run API and use it at those lines in MCMC.run. In addition, to set up different initial values for different chains, you will need to provide init_params to MCMC.run and make a FR/PR to not overwrite init_params here if it is not None. After those changes, the code would be

mcmc = MCMC(...)
# (optionally) run those statements to update inverse mass matrix
# mcmc._compile(..., init_params=batched_unconstrained_init_values)
# init_state = mcmc.last_state
# init_state = update_inverse_mass_matrix(init_state)
mcmc.run(..., init_params=batched_init_values)
next_init_state = mcmc.last_state
mcmc.run(..., init_state=next_init_state)

Thank you so much!!! This is so helpful. The first option looks good and thanks for giving me so many details. It looks like after I changed “dynamic_args=True,” I need to pass “model_kwargs” into “fori_collect” as you suggested. However, I’m not entirely sure how to pass that info in. I’m just testing with one chain. I now have something like this:
hmc_states = fori_collect(
    num_warmup,
    num_warmup + num_samples,
    sample_kernel,
    hmc_state,
    model_kwargs={'T_max': T_max, 'R0': R0, 'A': A, 'O': O, 'RT_max': RT_max},
)
And it gave me errors. Could you perhaps elaborate on how I should pass this information?

As for the second option, I have a quick clarification question: what do you mean by FR/PR? My previous attempt with init_params had no impact, i.e., the result looked the same regardless of whether I passed “init_params” to “mcmc.run(…).” I believe something overwrote it.

Thank you so much again! I really appreciate it!

Hi @janelai22, I think you need partial(sample_kernel, model_args=model_args, model_kwargs=model_kwargs) as in my last comment. By FR/PR, I meant a feature request or pull request at https://github.com/pyro-ppl/numpyro

Hi,

Thank you so much! Sorry for the confusion; this is clear now! However, I got an error, as shown in the last cell here. It works fine if I set dynamic_args=False. I’m not familiar with JAX. Would you mind telling me what might be causing the error?

Thanks so much! Sorry for the delay.

I guess you need this

call hmc with potential_fn_gen=model_info.potential_fn to get init_kernel and sample_kernel


It is working nicely!!! Thank you so much!