Is `callback` supported for numpyro NUTS sampler?

Hi all,

I’m using NumPyro as the NUTS sampler for a PyMC model, and I would like to add a callback that monitors the number of divergences and stops sampling once it exceeds some threshold X.

I found the response below on the PyMC side and want to know whether that has changed, or whether there is now a way to do this.

This may only be possible in PyMC, but I have already asked there (#7419). 🙂

> `sample_numpyro_nuts` does track divergences and other sampling metrics, though it is true that they are not reported on the fly.

I’m also looking for an example of how to implement this, as I’m fairly new to the Bayesian framework and have no experience with NumPyro. 🙂

Thanks!

Currently we don’t support callbacks, but support could be added. In the meantime you can use the low-level API, which gives you more control over what you want to achieve; see the "Example: Stochastic Volatility" page in the NumPyro documentation.

Thanks @fehiepsi! Do you have an example of how the callback could be implemented? I suspect it can be done with `fori_collect`, but my knowledge of NumPyro is too limited for me to follow what is going on in the example you shared.

Thank you!

I think you can run a Python loop over `sample_kernel` to collect samples and call whatever callback you like at each step.