Using numpyro as a DSL for deterministic inference with INLA

I’m half just sharing a fun application of numpyro and half asking for advice.

I’ve been using numpyro as a domain-specific language for specifying the models in my INLA implementation (link is to a basic tutorial). The main mathematical object needed for INLA is the log likelihood. I’m using something like the code on the handlers docs page to compute log likelihoods. Using numpyro like this for model definition is nice for many reasons, but it’s particularly nice to be able to test against MCMC with almost no effort.

Other details:

  • I’m mostly working with low-dimensional models (2-10 parameters), but I’m trying to get close to bare-metal performance since I need to perform trillions of inferences with these models. numpyro/JAX has been surprisingly good for this, even compared against some highly optimized hand-written CUDA code.
  • As part of INLA, I need to compute Hessians. I’ve mostly been trusting JAX to do this, but it seems to have a big impact on JIT compilation time.
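One way to isolate how much of the cost is compilation rather than execution is JAX's ahead-of-time path, `jax.jit(f).lower(...).compile()`, timed on its own. The sketch below assumes a hypothetical 3-parameter objective `neg_log_joint` standing in for a real model; it only measures compile time and makes no claim about reducing it. `jax.hessian` is forward-over-reverse (`jacfwd` of `jacrev`), which is usually appropriate for a handful of parameters.

```python
import time

import jax
import jax.numpy as jnp


def neg_log_joint(theta):
    # Hypothetical smooth objective standing in for a model's negative log joint.
    return 0.5 * jnp.sum(theta**2) + jnp.sum(jnp.exp(theta)) - jnp.sum(theta)


# jax.hessian(f) is jacfwd(jacrev(f)).
hess_fn = jax.jit(jax.hessian(neg_log_joint))

theta0 = jnp.zeros(3)

# Trace + lower + compile, timed separately from execution.
t0 = time.perf_counter()
compiled = hess_fn.lower(theta0).compile()
compile_seconds = time.perf_counter() - t0

# Run the already-compiled executable.
H = compiled(theta0)

# The same Hessian spelled out explicitly, for comparison.
H_explicit = jax.jacfwd(jax.jacrev(neg_log_joint))(theta0)
```

For this objective the Hessian is diagonal with entries `1 + exp(theta_i)`, so at `theta0 = 0` it is `2 * I`. Swapping in a real model's log density and comparing `compile_seconds` with and without the Hessian is one way to build the small reproduction mentioned later in the thread.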

Overall, I’m wondering:

  • if folks have general advice for implementing non-sampling inference methods on top of numpyro.
    • are there ways to reduce compile times, given that we never need to sample? I’ve noticed that I still need to seed my models despite never sampling from them.
  • any other general advice.

I also just wanted to be a little more looped into the community.

Thanks everyone! numpyro is awesome!!!

Hi @tbenthompson, I’m not really sure how to answer your question/comment, but assuming there’s room for additional speed-ups or ways to reduce compile time, arguably the best way to proceed would be to make an INLA PR (or first open a GitHub issue discussing implementation details). That would make it much easier to discuss potential speed-ups. Even if the final implementation needs to be modified or stripped down for your particular use case, the result would be useful for the rest of the community.

Hi @martinjankowiak. Sorry to be unclear! I was mostly just saying hi and sharing what I’m working on. I’ll consider putting together a small reproduction to demonstrate the compile-time issue. Thanks!!

Hi @tbenthompson, welcome. Regarding your final point, we have a Pyro community Slack instance where some more free-form discussion of research and applications happens. If you send me an email at ebingham@broadinstitute.org, I will invite you.

Regarding INLA in particular, you might be interested in using or drawing from some of the more advanced Gaussian operations we’ve implemented in Pyro and Funsor, though this may be less relevant depending on just how extreme your performance constraints are.

For example, funsor.gaussian.Gaussian expressions are general factor graph representations of multivariate Gaussian distributions with block-structured precision matrices, and should be compatible with NumPyro/JAX. pyro.ops.gaussian.Gaussian is more optimized but restricted to PyTorch and to dense precision matrices.
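To illustrate what "block-structured precision" buys, here is a plain-JAX sketch of a Gaussian in information form, log p(x) = -½ xᵀPx + hᵀx + const, where integrating out a block of variables is a Schur complement of the precision. This is not funsor's API; `marginalize_block`, `keep` and `drop` are hypothetical names used only for this example.

```python
import jax.numpy as jnp


def marginalize_block(info_vec, precision, keep, drop):
    # Information-form Gaussian: log p(x) = -0.5 x^T P x + h^T x + const.
    # Integrating out x[drop] leaves an information-form Gaussian over x[keep]:
    #   P' = P_kk - P_kd P_dd^{-1} P_dk
    #   h' = h_k  - P_kd P_dd^{-1} h_d
    P_kk = precision[jnp.ix_(keep, keep)]
    P_kd = precision[jnp.ix_(keep, drop)]
    P_dd = precision[jnp.ix_(drop, drop)]
    h_k = info_vec[keep]
    h_d = info_vec[drop]
    # Use linear solves rather than forming P_dd^{-1} explicitly.
    new_precision = P_kk - P_kd @ jnp.linalg.solve(P_dd, P_kd.T)
    new_info_vec = h_k - P_kd @ jnp.linalg.solve(P_dd, h_d)
    return new_info_vec, new_precision


# Build an information-form Gaussian from a mean and covariance.
cov = jnp.array([[2.0, 0.5, 0.3], [0.5, 1.0, 0.2], [0.3, 0.2, 1.5]])
mean = jnp.array([1.0, -1.0, 0.5])
precision = jnp.linalg.inv(cov)
info_vec = precision @ mean

keep = jnp.array([0, 2])
drop = jnp.array([1])
h_m, P_m = marginalize_block(info_vec, precision, keep, drop)
```

As a sanity check, the marginal recovered this way has covariance `cov[keep][:, keep]` and mean `mean[keep]`. When the precision is sparse across blocks (as in a factor graph), each elimination only touches neighboring blocks, which is the structure funsor's Gaussian exploits.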

You can see an example of funsor.Gaussian used to implement an inference algorithm in our experimental pyro.infer.autoguide.AutoGaussian autoguide, which generates a multivariate Gaussian variational approximation whose conditional independence structure matches that of the true posterior for a given Pyro model. I believe it would be possible to implement a version of INLA (in Pyro or Numpyro) along similar lines, where the parameters of the approximation are obtained by computing a Hessian at the mode rather than optimizing a variational bound.
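A minimal sketch of "Hessian at the mode" in plain JAX: undamped Newton steps to find the mode of a log joint density, then the negative inverse Hessian at the mode as the Gaussian covariance. This is only the basic Laplace approximation, not full INLA and not any Pyro/NumPyro API. `laplace_approximation` and the logistic-regression `example_log_joint` (with made-up data) are hypothetical. There is no line search or damping, so a real implementation would need safeguards for poorly conditioned models.

```python
import jax
import jax.numpy as jnp


def laplace_approximation(log_joint, theta_init, num_newton_steps=20):
    # Locate the mode of `log_joint` with plain Newton steps, then use
    # (-Hessian at the mode)^{-1} as the covariance of the Gaussian approximation.
    grad_fn = jax.grad(log_joint)
    hess_fn = jax.hessian(log_joint)

    def newton_step(_, theta):
        return theta - jnp.linalg.solve(hess_fn(theta), grad_fn(theta))

    mode = jax.lax.fori_loop(0, num_newton_steps, newton_step, theta_init)
    precision = -hess_fn(mode)
    cov = jnp.linalg.inv(precision)
    return mode, cov


# Hypothetical data for a 2-parameter logistic regression.
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
y = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0])


def example_log_joint(theta):
    # Bernoulli likelihood with logits = intercept + slope * x, N(0, 1) priors.
    logits = theta[0] + theta[1] * x
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits) + (1.0 - y) * jax.nn.log_sigmoid(-logits))
    log_prior = -0.5 * jnp.sum(theta**2)
    return log_lik + log_prior


mode, cov = laplace_approximation(example_log_joint, jnp.zeros(2))
```

Because the mode and Hessian come from deterministic Newton iterations, the whole function can be wrapped in `jax.jit` and `jax.vmap` over many datasets, which fits the "many small inferences" workload described at the top of the thread.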