Is there a recommended way to track the samples during MCMC and log them to TensorBoard (or wandb, which I use)? This would be helpful when inference takes a long time, since it makes it possible to spot issues earlier.
So far, I've gotten somewhere by extending NUTS with some hook functions:
```python
from numpyro.infer import NUTS


class NUTSWithHooks(NUTS):
    hooks = None

    def set_hooks(self, hooks):
        self.hooks = hooks

    def get_diagnostics_str(self, state):
        # MCMC calls this to update the progress bar, i.e. once per iteration.
        if self.hooks is not None:
            for hook in self.hooks:
                hook(state, self)
        return super().get_diagnostics_str(state)
```
I then log the state by registering a hook with `nuts_kernel.set_hooks([log_numpyro_state])`, where:
```python
import numpy as np
import wandb

def log_numpyro_state(state, sampler):
    log_dict = {key: np.mean(value) for key, value in state.z.items()}
    wandb.log(log_dict)
```
The main issue is that this logs the unconstrained values, whereas the sites in my model are constrained. It is also clearly a bit of a hack, so ideally I'm looking for a neater solution.