Hello!
How can I access the warmup chains in NumPyro? Here’s the relevant part of my code, showing where I set collect_warmup=True:
import arviz
import jax
from numpyro import distributions as dist, infer
sampler = infer.MCMC(
infer.NUTS(model),
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=num_chains,
progress_bar=True)
# Use separate random keys for the warmup and sampling phases
warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
# (Attempt to) save warmup chains
sampler.warmup(warmup_key, time_axis, Y_unc, Y_observed_val=Y_observed, collect_warmup=True)
# Draw samples from the posterior
sampler.run(sample_key, time_axis, Y_unc, Y_observed_val=Y_observed)
# Save results
data = arviz.from_numpyro(sampler)
data.to_netcdf(f"{out_path}/{sn}_{filt}_numpyro.nc")
What I’ve tried so far
- When I read the .nc file back, I only see the groups {posterior, log_likelihood, sample_stats, observed_data} for the sampling phase, not for the warmup phase
- When I call sampler.get_samples(), the draws are still only from the sampling phase
- In sampler.get_extra_fields(), there’s no “warmup” entry either
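One possible way to get the warmup draws into the netCDF file is ArviZ's warmup_posterior group. The following is a hedged sketch using arviz.from_dict with save_warmup=True. The random arrays are placeholders (not real results) standing in for the dictionaries returned by get_samples(group_by_chain=True) after warmup() and after run(), and the file name is hypothetical.

```python
import os
import tempfile

import arviz as az
import numpy as np

# Placeholder draws (NOT real results), shaped (chain, draw) as ArviZ expects.
# They stand in for get_samples(group_by_chain=True) after warmup() / run().
rng = np.random.default_rng(0)
num_chains, num_warmup, num_samples = 2, 100, 200
warmup_samples = {"mu": rng.normal(size=(num_chains, num_warmup))}
posterior_samples = {"mu": rng.normal(size=(num_chains, num_samples))}

# save_warmup=True keeps the warmup_posterior group instead of dropping it.
idata = az.from_dict(
    posterior=posterior_samples,
    warmup_posterior=warmup_samples,
    save_warmup=True,
)

# Write to netCDF and read it back (hypothetical file name).
path = os.path.join(tempfile.mkdtemp(), "example_numpyro.nc")
idata.to_netcdf(path)
reloaded = az.from_netcdf(path)
print(reloaded.groups())
```

The sampling-phase groups produced by arviz.from_numpyro (log_likelihood, sample_stats, and so on) are not covered by this sketch; it only shows that a warmup group can live in the same InferenceData object and file.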