Hi,
I’ve trained a Bayesian Neural Network (BNN) using NumPyro and want to use its predictions in a minimization problem. Specifically, I want to find the input that minimizes an objective function that contains the pre-trained BNN.
My idea is to use `optax`, where the cost function is the mean of the samples drawn from the BNN model. I can use `jax.grad` to compute the gradient of this function, so in theory I should be able to minimize it this way. However, the gradient computation is very slow (I haven't tried `jax.jit` yet), so I suspect this might not be the right way to approach the problem.
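One common way to make this fast is sketched below. It assumes that the network can be written as a pure function of its weights and that posterior weight samples are already available, e.g. from `mcmc.get_samples()`. The idea is to `vmap` a deterministic forward pass over the posterior draws, average the outputs, and `jit` the value-and-gradient function once, so each optimization step reuses the compiled computation. The architecture, the parameter names (`w1`, `b1`, `w2`, `b2`), and the random "posterior" below are hypothetical stand-ins for your own model and samples.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for posterior samples; in practice use mcmc.get_samples(),
# where every leaf has a leading axis of size S (number of posterior draws).
S, D_IN, H = 200, 3, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
posterior = {
    "w1": jax.random.normal(k1, (S, D_IN, H)) * 0.5,
    "b1": jax.random.normal(k2, (S, H)) * 0.1,
    "w2": jax.random.normal(k3, (S, H, 1)) * 0.5,
    "b2": jax.random.normal(k4, (S, 1)) * 0.1,
}

def forward(params, x):
    """Deterministic forward pass for ONE posterior draw (hypothetical 1-hidden-layer MLP)."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"]).squeeze(-1)  # scalar output

def objective(x, samples):
    # Map the forward pass over the leading sample axis of every leaf in `samples`,
    # keeping the input x fixed, then average: a Monte Carlo predictive mean.
    preds = jax.vmap(forward, in_axes=(0, None))(samples, x)
    return jnp.mean(preds)

# Compile value-and-gradient w.r.t. x (argument 0) once; later calls reuse it.
value_and_grad = jax.jit(jax.value_and_grad(objective))

def minimize(x0, samples, lr=0.1, steps=200):
    """Plain gradient descent on the input x; posterior samples stay fixed."""
    x = x0
    for _ in range(steps):
        _, g = value_and_grad(x, samples)
        x = x - lr * g
    return x

x0 = jnp.zeros(D_IN)
x_opt = minimize(x0, posterior)
```

Passing the samples as an argument, rather than closing over them, avoids baking large arrays into the compiled function as constants. The manual gradient step could be replaced by an `optax` optimizer such as `optax.adam` (via `opt.init`, `opt.update`, and `optax.apply_updates`) without changing the jitted objective.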
How to do this efficiently?
Thanks for any insight!