Logging samples during inference

Is there a recommended way to track samples during MCMC and log them to TensorBoard (or wandb, which I use)? This seems helpful when inference takes a long time, since it makes it possible to spot issues more quickly.

So far I figured I could get somewhere by extending NUTS with some hook functions:

```python
class NUTSWithHooks(NUTS):
    hooks = None

    def set_hooks(self, hooks):
        self.hooks = hooks

    def get_diagnostics_str(self, state):
        # called once per sample while the progress bar is active,
        # so this is a convenient place to run logging hooks
        if self.hooks is not None:
            for hook in self.hooks:
                hook(state, self)
        return super().get_diagnostics_str(state)
```

And logging the state using `nuts_kernel.set_hooks([log_numpyro_state])`:

```python
def log_numpyro_state(state, sampler):
    # log the mean of each latent site's current value
    log_dict = {key: np.mean(value) for key, value in state.z.items()}
    wandb.log(log_dict)
```

The main issue is that this logs the unconstrained values, while my model works with constrained values. Also, this is obviously a bit of a hack, so ideally I am looking for a neater solution.

Your solution is very nice. For constrained samples, you can use the kernel's `postprocess_fn`.
