Simulated Annealing with NUTS to optimize a jitted potential function

I cannot figure out how to implement a NUTS-based simulated annealing algorithm to approximate the maximum of an arbitrary potential function. More specifically, the potential function involves a sequential computation, so it needs to be jitted. To run a simulated annealing optimizer, the potential function must accept a manually controlled temperature parameter, while the rest of the parameters are sampled via MCMC. Moreover, the temperature parameter can be adjusted at each iteration, and the acceptance ratio should be calculated using the latest temperature value. How can I achieve this?
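To make the setup concrete, here is a minimal sketch of the kind of potential I mean. The sequential part and the functional form are placeholders (my real model is different); the point is that the whole thing is jitted and takes `temperature` as an explicit argument:

```python
import jax
import jax.numpy as jnp

# Toy temperature-dependent potential. The lax.scan stands in for the
# sequential computation that forces jitting; the actual model differs.
@jax.jit
def potential(theta, temperature):
    def step(carry, x):
        carry = carry + jnp.sin(carry + x)
        return carry, carry
    final, _ = jax.lax.scan(step, 0.0, theta)
    energy = jnp.sum(theta ** 2) - final
    # Classic annealing scaling: energy divided by the current temperature.
    return energy / temperature

val = potential(jnp.array([0.5, -0.3, 1.2]), 2.0)
print(val)
```

The question is how to hand a potential like this to NUTS while I keep updating `temperature` between iterations.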

Finding the maximum likelihood sounds like a pure optimization problem rather than an MCMC problem. NUTS requires the likelihood function to be static, but simulated annealing necessarily requires it to change. Using the NumPyro inference utilities you could easily build a callable, temperature-dependent potential, but I don't think the infer.MCMC samplers will be able to use it the way you want. Instead, could you look into JAX's more direct stochastic optimizers? Maybe use the SVI module, which is a sort of middle ground between the two?
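If you do go the direct route, one option is to skip infer.MCMC entirely and write the annealing loop yourself in plain JAX: a random-walk Metropolis step whose acceptance ratio uses whatever temperature the schedule dictates at that iteration. This is a sketch, not your model; the quadratic `potential`, the 0.1 proposal scale, and the geometric cooling schedule are all assumptions:

```python
import jax
import jax.numpy as jnp

# Toy potential with minimum at theta = [1, 1, 1] (an assumption, not the
# asker's actual model).
def potential(theta):
    return jnp.sum((theta - 1.0) ** 2)

def sa_step(carry, temperature):
    theta, energy, key = carry
    key, k_prop, k_acc = jax.random.split(key, 3)
    # Random-walk proposal with a fixed step size.
    proposal = theta + 0.1 * jax.random.normal(k_prop, theta.shape)
    prop_energy = potential(proposal)
    # Metropolis acceptance using the *current* temperature.
    accept = jax.random.uniform(k_acc) < jnp.exp((energy - prop_energy) / temperature)
    theta = jnp.where(accept, proposal, theta)
    energy = jnp.where(accept, prop_energy, energy)
    return (theta, energy, key), energy

key = jax.random.PRNGKey(0)
theta0 = jnp.zeros(3)
schedule = jnp.geomspace(1.0, 1e-3, 2000)  # geometric cooling schedule
(theta, energy, _), _ = jax.lax.scan(sa_step, (theta0, potential(theta0), key), schedule)
print(theta, energy)
```

Because the temperature is just an element of the scanned-over schedule, you can change it however you like per iteration, which is exactly the control infer.MCMC doesn't give you. The gradient-aware proposal that NUTS provides is lost, but you could substitute a Langevin-style proposal using `jax.grad(potential)` if that matters for your problem.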