Save NumPyro MCMC object state for additional sampling


I am trying to save a NumPyro MCMC object and later use its last state as the warmup state for further sampling. I set `mcmc.post_warmup_state = mcmc.last_state`, as suggested in the documentation, but I get an error when running the loaded model. The code I am using is below.

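```python
import pickle

from jax import random
from numpyro.infer import MCMC, NUTS

# Run the model (`model`, `warmup_steps`, `sample_steps` are defined elsewhere)
mcmc = MCMC(NUTS(model), num_warmup=warmup_steps,
            num_samples=sample_steps, num_chains=1)
mcmc.run(random.PRNGKey(0))

# Save the fitted MCMC object to a pickle
with open("mcmc_model.pickle", "wb") as output_file:
    pickle.dump(mcmc, output_file)

# Load the pickled object and reuse its last state as the post-warmup state
with open("mcmc_model.pickle", "rb") as input_file:
    mcmc = pickle.load(input_file)

mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)  # second call: raises the TypeError
```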

The second call to `mcmc.run` raises `TypeError: 'NoneType' object is not callable`. Does anyone know whether the issue is caused by saving the MCMC object as a pickle?

I think this is fixed in dev (see this test).

Thanks @fehiepsi! Just tested and this is fixed in dev.