Using a probabilistic model in a minimization problem

Hi,

I’ve trained a Bayesian Neural Network (BNN) using NumPyro and want to use its predictions in a minimization problem. Specifically, I want to find the input that minimizes an objective function that includes the pre-trained BNN.

My idea is to use optax, with a cost function based on the mean of samples drawn from the BNN, and jax.grad to compute the gradient of that cost. In theory this should let me minimize the function, but the gradient computation is very slow (I have not tried jax.jit yet), so I suspect this is not the right way to approach the problem.

How can I do this efficiently?

Thanks for any insight!