Custom action based on detecting the warmup stage / a certain iteration number

During MCMC iterations with NUTS / HMCGibbs, is it possible to detect whether

  • the iteration is in the warmup phase
  • the iteration is the n-th run

so that I can apply a special action/strategy based on that?

I think you can subclass the kernel and wrap its sample method. There you can extract that info from the input state (see the Markov Chain Monte Carlo (MCMC) page of the NumPyro documentation).