I’m half just sharing a fun application of numpyro and half asking for advice.
I’ve been using the numpyro tools as a domain-specific language for specifying models when implementing INLA (link is to a basic tutorial). The main mathematical object needed for INLA is the log likelihood, and I’m using something like the code in the handlers docs page to compute it. Using numpyro like this for model definition is quite nice for many reasons, but it’s particularly nice to be able to test against MCMC with almost no effort.
Other details:
- I’m mostly working with low-dimensional models (2-10 parameters), but trying to get close to bare-metal performance since I need to perform trillions of inferences with these models. numpyro/jax has been surprisingly good for this, even when compared against some highly optimized hand-written CUDA code.
- As part of INLA, I need to compute Hessians. I’ve mostly been trusting JAX to do this, but it seems to have a big impact on JIT compilation time.
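For reference, the Hessian step looks roughly like the sketch below. The two-parameter `log_joint` here is a hypothetical stand-in written in plain JAX, not one of my real models. It uses `jax.hessian` and times compilation separately from execution through the ahead-of-time `lower(...).compile()` path, which is how I have been isolating the compile cost.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a log joint density over a flat parameter vector
# theta = (mu, log_sigma), with standard normal priors on both.
def log_joint(theta, y):
    mu, log_sigma = theta[0], theta[1]
    sigma = jnp.exp(log_sigma)
    log_prior = -0.5 * (mu**2 + log_sigma**2)
    log_lik = jnp.sum(-0.5 * ((y - mu) / sigma) ** 2 - log_sigma)
    return log_prior + log_lik


# Hessian with respect to the first argument (theta).
hess_fn = jax.jit(jax.hessian(log_joint))

y = jnp.array([0.5, -0.2, 1.0])
theta = jnp.array([0.1, 0.0])

# Compile ahead of time so the compile cost is measured on its own.
t0 = time.perf_counter()
compiled = hess_fn.lower(theta, y).compile()
compile_seconds = time.perf_counter() - t0

H = compiled(theta, y)  # shape (2, 2)
```

The compiled function is tied to the input shapes and dtypes it was lowered with, so in this sketch a different length of `y` would trigger another compile.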
Overall, I’m wondering:
- if folks have general advice for implementing non-sampling inference methods on top of numpyro.
- whether there are ways to reduce compile times, given that we don’t need to sample. I’ve noticed that I still need to seed my models despite never sampling from them.
- any other general advice.
I also just wanted to be a little more looped into the community.
Thanks everyone! numpyro is awesome!!!