Custom JVP is breaking NUTS inference

Hello devs. I am trying to fit a type of sigmoid curve to my data using Bayesian regression. If I define a custom JVP for this function, SVI (with AutoNormalGuide) works just as well as it does without the custom JVP. However, MCMC with NUTS breaks when the custom JVP is used (poor r_hat values), whereas it works well without it. I would like to use custom JVPs for more complex functions…
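
A quick first check is to compare the custom JVP against plain autodiff in both forward mode (`jax.jvp`) and reverse mode (`jax.grad`). Reverse mode matters because NUTS differentiates the log density with `grad`. The actual curve isn't shown, so the logistic `sigmoid_curve(x, slope, offset)` below is a hypothetical stand-in, and `compare_modes` is an illustrative helper, not part of any library. This is a minimal sketch under those assumptions.

```python
import jax
import jax.numpy as jnp


def sigmoid_curve_ref(x, slope, offset):
    """Hypothetical sigmoid curve, differentiated by plain autodiff."""
    return jax.nn.sigmoid(slope * x + offset)


@jax.custom_jvp
def sigmoid_curve(x, slope, offset):
    """Same curve, but with a hand-written JVP rule."""
    return jax.nn.sigmoid(slope * x + offset)


@sigmoid_curve.defjvp
def _sigmoid_curve_jvp(primals, tangents):
    x, slope, offset = primals
    x_dot, slope_dot, offset_dot = tangents
    y = sigmoid_curve(x, slope, offset)
    # Tangent of z = slope * x + offset (product rule for slope * x).
    z_dot = slope * x_dot + x * slope_dot + offset_dot
    # d sigmoid(z) / dz = y * (1 - y). The output tangent must be linear in the
    # input tangents, otherwise JAX cannot transpose the rule for jax.grad.
    return y, y * (1.0 - y) * z_dot


def compare_modes(x, slope, offset):
    """Max abs difference between custom and reference derivatives,
    in forward mode (jax.jvp) and reverse mode (jax.grad)."""
    args = (x, slope, offset)
    tangents = (jnp.ones_like(x), jnp.ones_like(slope), jnp.ones_like(offset))
    _, t_custom = jax.jvp(sigmoid_curve, args, tangents)
    _, t_ref = jax.jvp(sigmoid_curve_ref, args, tangents)
    fwd_err = float(jnp.max(jnp.abs(t_custom - t_ref)))

    # Reverse mode w.r.t. the parameters, which is what a gradient-based
    # sampler needs.
    def summed(curve):
        return lambda s, o: jnp.sum(curve(x, s, o))

    g_custom = jax.grad(summed(sigmoid_curve), argnums=(0, 1))(slope, offset)
    g_ref = jax.grad(summed(sigmoid_curve_ref), argnums=(0, 1))(slope, offset)
    rev_err = max(float(jnp.abs(a - b)) for a, b in zip(g_custom, g_ref))
    return fwd_err, rev_err


x = jnp.linspace(-3.0, 3.0, 7)
print(compare_modes(x, jnp.array(1.5), jnp.array(-0.5)))
```

Both errors should be at float32 rounding level. When no reference implementation exists, JAX's `jax.test_util.check_grads` can compare derivatives against finite differences instead.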

Could you please give me some idea of what might be going wrong? I tried reading through the NUTS source code, but I couldn't find where exactly the differentiation call is made.

Hi @mathlad, this line calls `grad`. I suggest first experimenting with the gradient of the `log_density` utility to see what goes wrong.
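
One way to follow this suggestion is to evaluate the gradient at many parameter values, including ones well away from the posterior, and compare it with the version that has no custom JVP. NumPyro's own utility is `log_density` in `numpyro.infer.util`. To keep this sketch self-contained, `make_log_density` below is a hypothetical hand-written stand-in: a Gaussian likelihood around a hypothetical logistic curve, with standard-normal priors on unconstrained `slope`, `offset`, and `log_sigma`. The data are synthetic, and `check_log_density_grads` is an illustrative helper.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def sigmoid_curve(x, slope, offset):
    """Hypothetical curve with a hand-written JVP rule."""
    return jax.nn.sigmoid(slope * x + offset)


@sigmoid_curve.defjvp
def _sigmoid_curve_jvp(primals, tangents):
    x, slope, offset = primals
    x_dot, slope_dot, offset_dot = tangents
    y = sigmoid_curve(x, slope, offset)
    return y, y * (1.0 - y) * (slope * x_dot + x * slope_dot + offset_dot)


def sigmoid_curve_ref(x, slope, offset):
    """Same curve, differentiated by plain autodiff."""
    return jax.nn.sigmoid(slope * x + offset)


def make_log_density(curve, x, y_obs):
    """Hypothetical stand-in for a model's log density over unconstrained
    params = (slope, offset, log_sigma), up to additive constants."""
    def log_density(params):
        slope, offset, log_sigma = params
        mu = curve(x, slope, offset)
        z = (y_obs - mu) * jnp.exp(-log_sigma)
        log_lik = jnp.sum(-0.5 * z**2 - log_sigma)
        return log_lik - 0.5 * jnp.sum(params**2)  # standard-normal priors
    return log_density


def check_log_density_grads(key, x, y_obs, n_points=256, spread=3.0):
    # Spread the test points beyond the posterior, since a sampler may wander
    # far from the mode (e.g. early in warmup).
    points = spread * jax.random.normal(key, (n_points, 3))
    grad_custom = jax.vmap(jax.grad(make_log_density(sigmoid_curve, x, y_obs)))(points)
    grad_ref = jax.vmap(jax.grad(make_log_density(sigmoid_curve_ref, x, y_obs)))(points)
    all_finite = bool(jnp.all(jnp.isfinite(grad_custom)))
    # Error relative to the largest reference gradient component at each point.
    scale_per_point = jnp.max(jnp.abs(grad_ref), axis=1, keepdims=True) + 1.0
    rel_err = float(jnp.max(jnp.abs(grad_custom - grad_ref) / scale_per_point))
    return all_finite, rel_err


key_noise, key_points = jax.random.split(jax.random.PRNGKey(0))
x = jnp.linspace(-3.0, 3.0, 50)
y_obs = jax.nn.sigmoid(1.5 * x - 0.5) + 0.05 * jax.random.normal(key_noise, x.shape)
print(check_log_density_grads(key_points, x, y_obs))
```

With a correct rule, the gradients are all finite and the relative error stays at rounding level. Non-finite values or a large error at some points would suggest that the JVP rule is wrong, or numerically unstable, somewhere in parameter space that the sampler may reach.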