Writing lower-level C code for a model when doing SVI

Hi,

I am using the NumPyro library to do SVI. To make my code faster, I want to write some parts of it in lower-level C/C++.

What do I need for this? Do you have any references (or examples) I could use?

Thank you,
Atharva

I found this tutorial: "Extending JAX with custom C++ and CUDA code" by Dan Foreman-Mackey. 🙂
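A minimal sketch of a lighter-weight alternative to the XLA custom call used in that tutorial (a different technique, not the tutorial's method): wrap the compiled routine as a host callback with `jax.pure_callback`. Here `host_forward` is a hypothetical NumPy stand-in for the C/C++ function (in practice you might expose it to Python with ctypes or pybind11); JAX only needs to be told the shape and dtype of the result. The callback runs on the host through Python, so each call carries overhead that the custom-call approach avoids, and `pure_callback` by itself cannot be differentiated.

```python
import numpy as np
import jax
import jax.numpy as jnp

# C/C++ code usually works in double precision, so enable float64 in JAX.
jax.config.update("jax_enable_x64", True)


def host_forward(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the compiled C/C++ routine.
    # A real version would call into a shared library (ctypes, pybind11, ...).
    return np.sin(x) * x


def forward(x):
    # Describe the output so JAX can trace through the opaque host call.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(
        # The callback receives NumPy arrays and must return the declared dtype.
        lambda a: np.asarray(host_forward(np.asarray(a)), dtype=out_spec.dtype),
        out_spec,
        x,
    )


x = jnp.linspace(0.0, 1.0, 8)
y = jax.jit(forward)(x)  # works under jit; the host call happens at run time
```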


I have a similar problem: I have a forward model written in C++ that maps a 3D input array of doubles to a 3D output array of doubles, and I also have its variation (derivative) with respect to the input. I would like to bind it to Python and use some of the NumPyro samplers to quickly test which types of sampler would work best for me, and then possibly continue the development and integration on the Python side with NumPyro.

Now I am wondering whether integration with NumPyro will be possible (in particular, I am interested in the NUTS sampler) once I replicate what is done in dfm’s post above. As far as I can see, NumPyro uses JAX as its backend, so I would have to define the action of my forward model on JAX primitives as well, similarly to what dfm has done.

So my question is: is this all that is necessary for smooth integration with NumPyro, or would I need to do something extra to bind my C++ forward model to NumPyro?
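For NUTS, NumPyro needs the gradient of the log density, which it obtains from JAX autodiff, so the C++ model must also tell JAX how to propagate derivatives. Since the variation with respect to the input is already available, one option (a sketch under stated assumptions, not dfm's method) is to pair host callbacks with `jax.custom_vjp`: the forward pass calls the forward model and the backward pass calls a vector-Jacobian product routine. `host_forward` and `host_vjp` are hypothetical stand-ins (a toy elementwise cube) for the C++ functions, and `log_density` is a toy likelihood.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


# Hypothetical stand-ins for the C++ forward model and its adjoint.
def host_forward(x):
    return x ** 3


def host_vjp(x, g):
    # Vector-Jacobian product g^T J; here J = diag(3 x^2) for the toy model.
    return g * 3.0 * x ** 2


def _call(fn, spec, *args):
    # Run a host function on NumPy arrays and return the declared shape/dtype.
    return jax.pure_callback(
        lambda *a: np.asarray(fn(*[np.asarray(v) for v in a]), dtype=spec.dtype),
        spec,
        *args,
    )


@jax.custom_vjp
def forward(x):
    return _call(host_forward, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


def forward_fwd(x):
    # Keep the input as the residual needed by the backward pass.
    return forward(x), x


def forward_bwd(x, g):
    # One cotangent per primal input, returned as a tuple.
    return (_call(host_vjp, jax.ShapeDtypeStruct(x.shape, x.dtype), x, g),)


forward.defvjp(forward_fwd, forward_bwd)


def log_density(x):
    # Toy Gaussian likelihood on the model output: a scalar that NUTS differentiates.
    return -0.5 * jnp.sum(forward(x) ** 2)


x = jnp.full((2, 3, 4), 0.5)  # 3D input, as in the question
g_x = jax.jit(jax.grad(log_density))(x)  # analytically, -3 * x**5
```

A function wrapped this way can be called inside a NumPyro model like any other JAX function. One caveat worth checking against your JAX version: if the callback ends up under `jax.vmap` (for example with `chain_method="vectorized"` in `MCMC`), `pure_callback` needs to be told how to batch via its `vmap_method` argument (older releases used a `vectorized` flag).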


Yes. In addition to dfm’s post, the following may be useful: "Question about defining new JAX primitives" (google/jax, Discussion #12730 on GitHub).


Thank you! I will take a look and bother you here if I have more questions. 🙂