Customize the likelihood function for NUTS and specify the gradient

Hi! I am trying to use the NUTS sampler with a custom likelihood function. In my case, the likelihood is produced by a trained normalizing flow built with PyTorch. I tried Pyro first: I defined the potential_fn function and passed it to the NUTS kernel, and that works fine.

However, I need to add checkpoints to the time-consuming mcmc.run(), so I have to switch to NumPyro. The gradient information required by NUTS then becomes a problem, because I cannot express the normalizing flow in terms of JAX (jax.numpy) operations. Is there any way to customize the likelihood function and specify the gradient of the log-likelihood with respect to the parameters myself?

Hi @yihaoz, this seems tricky. I think the easier way is to rewrite the normalizing flow in Flax.
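
If rewriting the flow is not practical, another option that might work is to keep the external model and tell JAX its gradient explicitly with jax.custom_vjp, calling out to the host through jax.pure_callback. The snippet below is only a rough sketch under assumptions: external_log_prob and external_grad are hypothetical stand-ins (a standard normal written in NumPy) for functions that would evaluate the PyTorch flow and its gradient and return NumPy arrays. I have not tried this with a real flow.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for the trained flow. In a real setup these would
# call the PyTorch model and torch.autograd, then convert the result to NumPy.
def external_log_prob(theta):
    theta = np.asarray(theta)
    return np.asarray(-0.5 * np.sum(theta ** 2), dtype=theta.dtype)


def external_grad(theta):
    theta = np.asarray(theta)
    return np.asarray(-theta, dtype=theta.dtype)


@jax.custom_vjp
def log_likelihood(theta):
    # Evaluate the log-likelihood outside of JAX; the result is a scalar.
    out_type = jax.ShapeDtypeStruct((), theta.dtype)
    return jax.pure_callback(external_log_prob, out_type, theta)


def _log_likelihood_fwd(theta):
    # Save theta so the backward pass can evaluate the external gradient.
    return log_likelihood(theta), theta


def _log_likelihood_bwd(theta, g):
    # g is the cotangent of the scalar output; scale the external gradient by it.
    out_type = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    grad = jax.pure_callback(external_grad, out_type, theta)
    return (g * grad,)


log_likelihood.defvjp(_log_likelihood_fwd, _log_likelihood_bwd)


def potential_fn(theta):
    # NUTS works with the potential energy, i.e. the negative log-density.
    return -log_likelihood(theta)


theta0 = jnp.array([1.0, -2.0])
value, grad = jax.jit(jax.value_and_grad(potential_fn))(theta0)
```

If this works for your model, potential_fn could be passed to numpyro.infer.NUTS(potential_fn=potential_fn), with init_params supplied to mcmc.run(). Expect it to be slow, since every gradient evaluation leaves JAX and goes back to the host, and I would only try it with a single chain at first.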