Using a probabilistic model in a minimization problem

Hi,

I’ve trained a Bayesian Neural Network (BNN) using NumPyro and want to use its predictions in a minimization problem. Specifically, I want to find the input that minimizes an objective function that includes the pre-trained BNN.

My idea is to use optax, with a cost function based on the mean of samples drawn from the BNN, and jax.grad to compute the gradient of that function. In theory, this should let me minimize the function. However, the gradient computation is very slow (I have not tried jax.jit yet), so I suspect this may not be the right way to approach the problem.
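
A minimal sketch of this kind of setup, to make the question concrete. It is not the actual model: the random `posterior` dictionary is a hypothetical stand-in for posterior samples of the BNN weights (for example, what NumPyro's `mcmc.get_samples()` would return), and `bnn_forward`, the layer sizes, and the quadratic penalty in `objective` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for posterior samples of the BNN weights.
# The leading axis of every array indexes the posterior draws.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
n_draws, n_in, n_hidden = 200, 3, 16
posterior = {
    "w1": jax.random.normal(k1, (n_draws, n_in, n_hidden)),
    "w2": jax.random.normal(k2, (n_draws, n_hidden)),
}

def bnn_forward(weights, x):
    """Evaluate a one-hidden-layer network for a single posterior draw."""
    hidden = jnp.tanh(x @ weights["w1"])
    return hidden @ weights["w2"]

def predict_all(x):
    # vmap over the draw axis of every weight array; x is shared across draws.
    return jax.vmap(bnn_forward, in_axes=(0, None))(posterior, x)

def objective(x):
    # Mean prediction over posterior draws plus a hypothetical penalty term.
    return predict_all(x).mean() + 0.1 * jnp.sum(x ** 2)

grad_fn = jax.grad(objective)  # gradient with respect to the input x
x0 = jnp.ones(n_in)
g = grad_fn(x0)
```

Evaluating all draws with a single `jax.vmap` call, rather than a Python loop over samples, keeps the whole objective as one traceable function that `jax.grad` can differentiate in one pass.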

How can I do this efficiently?

Thanks for any insight!

Hi @synchronist,

It seems like you already know what the first step is (JIT). If you want more advice, you’ll need to be more specific and provide a small example that you consider inefficient. You may also want to clarify what you mean by efficient; based on your own JIT suggestion, I assume you want to minimize the runtime.

Best, Ola