Accessing warmup chains

Hello!

How can I access warmup chains in numpyro? Here’s part of my code to show how I’ve set collect_warmup=True:

import arviz
import jax
from numpyro import distributions as dist, infer

sampler = infer.MCMC(
    infer.NUTS(model),
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
    progress_bar=True,
)

# (Attempt to) save warmup chains
sampler.warmup(jax.random.PRNGKey(0), time_axis, Y_unc, Y_observed_val=Y_observed, collect_warmup=True)
       
# Draw samples from the posterior
sampler.run(jax.random.PRNGKey(0), time_axis, Y_unc, Y_observed_val=Y_observed)

# Save results
data = arviz.from_numpyro(sampler)
data.to_netcdf(f"{out_path}/{sn}_{filt}_numpyro.nc")

What I’ve tried so far

  • Reading the .nc file, I only see the {posterior, log_likelihood, sample_stats, observed_data} groups for the sampling phase, not the warmup phase
  • Calling sampler.get_samples() still returns only the sampling-phase draws
  • sampler.get_extra_fields() has no “warmup” entry either

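You can call get_samples() after running warmup, before you call run(). With collect_warmup=True it returns the warmup draws at that point; run() then replaces them with the posterior samples.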
