I am trying to save intermediate samples while an MCMC chain is running. My model is slow, so I would like to check how well it is doing while the chain runs.
As a test, I used the eight_schools example and defined a logger function like this:
def logger(kernel, samples, stage, i, logging_dict):
    if stage != "Warmup":
        for key in samples:
            logging_dict[key].append(samples[key])
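As a standalone sanity check of the logger itself (fake values here, no Pyro involved), the warmup filter and the appending behave as I expect:

```python
from collections import defaultdict

def logger(kernel, samples, stage, i, logging_dict):
    if stage != "Warmup":
        for key in samples:
            logging_dict[key].append(samples[key])

d = defaultdict(list)
logger(None, {"tau": -0.07}, "Sample", 0, d)  # sampling stage: recorded
logger(None, {"tau": 1.23}, "Warmup", 0, d)   # warmup stage: skipped
print(d["tau"])  # [-0.07]
```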
and then modified the MCMC call like this:
from collections import defaultdict

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
logging_dict = defaultdict(list)
mcmc = MCMC(
    nuts_kernel,
    num_samples=args.num_samples,
    warmup_steps=args.warmup_steps,
    num_chains=args.num_chains,
    hook_fn=lambda kernel, samples, stage, i: logger(
        kernel, samples, stage, i, logging_dict
    ),
)
mcmc.run(model, data.sigma, data.y)
mcmc.summary(prob=0.5)
However, the results in logging_dict and those from mcmc.get_samples() disagree. In particular, running with 10 samples and 2 warmup steps, I get a value of -0.0722 for the tau parameter logged with the hook_fn, and a value of 0.9304 from mcmc.get_samples().
Why are they different? I also noticed that the values recorded with hook_fn can fall outside the support of the prior.
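One thing I noticed while poking at the numbers (just my own guess, not something I found in the docs): the two tau values seem to be related by an exponential, which makes me wonder whether the hook is seeing log-transformed (unconstrained) values rather than values in the prior's support:

```python
import math

hooked_value = -0.0722    # tau as recorded by hook_fn
reported_value = 0.9304   # tau as returned by mcmc.get_samples()

# exp(-0.0722) comes out at roughly 0.9304 -- is the hook seeing
# tau on the unconstrained (log) scale?
print(math.exp(hooked_value))
```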