Hi all! I wanted to share some material I presented at PyData Berlin 2025, introducing variational inference with NumPyro.
- Slides: Scaling Probabilistic Models with Variational Inference
- Notebook: PyData Berlin 2025: Introduction to Stochastic Variational Inference with NumPyro - Dr. Juan Camilo Orduz
In particular, I illustrate each of the following with a small example:
- New Flax NNX integration
- Custom Optax optimizers
- Early stopping (custom training loop)
Many folks working with NumPyro were not aware of these features, so I think we should highlight them more within the community.
Feedback is welcome!