NumPyro + Flax?

Hi!
I am just writing with a quick question. Flax is currently in a pre-release state but will hopefully be fully released soon. JAX is currently on version 1.41 (and Flax only works properly with the latest version of JAX), while NumPyro currently works with version 1.37. Do you have an estimate of when everything could be updated and working together? Just wondering :blush:, I am aware it is a lot of work.

Thanks and take care! :slight_smile:

@artistworking - We should be able to update NumPyro soon so that it works with JAX 1.41. Unless there are breaking changes in JAX, it might already work, but I’ll need to check. Beyond that, I don’t think we need to do anything else to support libraries like Flax. Is there anything in particular that you are looking to do with Flax and NumPyro?

Hi!

Thanks! :slight_smile: I am working on a generative model that involves HMC (NUTS) and RNN/GRU/LSTM layers.

I think the easiest way is to write a potential_fn that takes as input the parameters you need to sample. If you want to add a Normal() prior for a parameter “weight”, simply add Normal().log_prob(weight) to the log joint density (the potential_fn returns the negative of that log joint density). You can also write helpers to make this easier.

Currently, NumPyro supports Stax modules for optimizing parameters with SVI. To do MCMC inference over the parameters of NN modules, it is better to use a potential_fn as described above (this applies to Stax, Flax, Haiku, …).

Whoa, ok. I think I have some reading to do then :slight_smile:. Let’s see what I manage to do. I am implementing the model in Pyro first, to see if the idea works and because I have more experience with it, and then I will port it to NumPyro. Thanks again!!!