Learning rate scheduling in NumPyro

(How) can one do learning rate scheduling of any kind in numpyro?

Pyro provides access to the PyTorch learning rate schedulers, and Pyro's ClippedAdam also has a dedicated learning rate decay parameter. However, I cannot find anything similar in NumPyro, or any example that does this.

I would even be happy to do this manually, if there is a way to change the learning rate of, e.g., numpyro.optim.Adam over the course of the SVI iterations.

We recommend using optax for training JAX programs. Please check out the optimizer schedules in its documentation. You can pass an optax optimizer directly to SVI (instead of numpyro.optim.Adam).
