Different results with hook_fn and mcmc.get_samples()

I am trying to save intermediate samples while running an MCMC chain. My model is slow, so I would like to check how well it is doing while the chain is still running.

As a test, I used the eight_schools example and defined a logger function like this:

def logger(kernel, samples, stage, i, logging_dict):
    if stage != "Warmup":
        for key in samples:
            logging_dict[key].append(samples[key])

and then modified the MCMC call like this:

    nuts_kernel = NUTS(conditioned_model, jit_compile=args.jit)
    logging_dict = defaultdict(list)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
        num_chains=args.num_chains,
        hook_fn=lambda kernel, samples, stage, i: logger(
            kernel, samples, stage, i, logging_dict
        ),
    )

    mcmc.run(model, data.sigma, data.y)
    mcmc.summary(prob=0.5)

However, the values in logging_dict and those from mcmc.get_samples() disagree. For example, when running with 10 samples and 2 warmup steps, the hook_fn logs a value of -0.0722 for the tau parameter, while mcmc.get_samples() returns 0.9304.

Why are they different? I also noticed that the values recorded with hook_fn can fall outside the support of the prior.

Hi @arnau, my guess is that hook_fn receives the unconstrained samples, whereas get_samples() returns the constrained samples. If your samples are constrained to be positive, for example, then I think you could transform them as follows:

from torch.distributions import biject_to, constraints

# biject_to(constraint) maps unconstrained -> constrained values (exp for positive support)
constrained_samples = biject_to(constraints.positive)(unconstrained_samples)

Let me know if that doesn’t work!
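As a quick check of this explanation against the numbers in the question, here is a minimal sketch assuming that tau has positive support (consistent with the negative value logged by hook_fn falling outside the prior's support) and that the bijection from the real line onto (0, inf) is exp, which is what PyTorch's `biject_to(constraints.positive)` uses:

```python
import math

# Assumption: tau has positive support, and the bijection from the
# unconstrained real line onto (0, inf) is exp (PyTorch's ExpTransform).
unconstrained_tau = -0.0722  # value logged by hook_fn in the question
constrained_tau = math.exp(unconstrained_tau)

print(f"{constrained_tau:.4f}")
```

This prints 0.9303, which agrees with the 0.9304 from get_samples() up to the rounding of the logged value.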

Hi @fritzo, thanks for your prompt answer!
That was indeed the problem, and it also explains why the samples sometimes fell outside the support of the prior. For future reference, this is what I did:

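def logger(kernel, samples, stage, i, logging_dict):
    if stage != "Warmup":
        for key in samples:
            unconstrained_samples = samples[key]
            # kernel.transforms[key] maps constrained -> unconstrained values,
            # so its inverse maps the hook's samples back to constrained space
            constrained_samples = kernel.transforms[key].inv(unconstrained_samples)
            logging_dict[key].append(constrained_samples)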

Cheers!
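The logger above can be exercised without running a full chain. The sketch below is hypothetical: `fake_kernel` is a stand-in that only mimics the `.transforms` attribute, assuming (as the working solution implies) that each entry maps constrained to unconstrained values, i.e. `biject_to(support).inv`. It also uses `stage.startswith("Warmup")` as a slightly more defensive version of the `!=` check, copies each value with `.detach().clone()`, and stacks the logged values into one tensor at the end:

```python
from collections import defaultdict
from types import SimpleNamespace

import torch
from torch.distributions import biject_to, constraints


def logger(kernel, samples, stage, i, logging_dict):
    # Skip warmup iterations (defensive variant of stage != "Warmup").
    if stage.startswith("Warmup"):
        return
    for key, unconstrained in samples.items():
        # kernel.transforms[key] maps constrained -> unconstrained,
        # so its .inv maps the hook's unconstrained value back.
        constrained = kernel.transforms[key].inv(unconstrained)
        # Defensive copy so later in-place changes cannot alter the log.
        logging_dict[key].append(constrained.detach().clone())


# Hypothetical stand-in for a NUTS kernel: only .transforms is mimicked,
# assuming each entry is biject_to(support).inv (constrained -> unconstrained).
fake_kernel = SimpleNamespace(transforms={"tau": biject_to(constraints.positive).inv})

logging_dict = defaultdict(list)
logger(fake_kernel, {"tau": torch.tensor([0.5])}, "Warmup", 0, logging_dict)
logger(fake_kernel, {"tau": torch.tensor([-0.0722])}, "Sample", 0, logging_dict)

# Stack the per-iteration values into a single (num_logged, ...) tensor.
tau = torch.stack(logging_dict["tau"])
print(tau.shape, tau)
```

Only the Sample-stage call is recorded, and the unconstrained -0.0722 comes back as roughly 0.9303, the same constrained value get_samples() reported.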