Learning rate scheduling in NumPyro

(How) can one do learning rate scheduling of any kind in NumPyro?

Pyro provides access to the PyTorch schedulers, and Pyro's ClippedAdam also has a dedicated learning rate decay parameter. I cannot find anything of the sort in NumPyro, however, nor any example that does this.

I would even be happy to do this manually, if there is a way to set the learning rate of, e.g., numpyro.optim.Adam over the course of the SVI iterations.

We recommend using Optax for training JAX programs. Please check out its optimizer schedules in the docs. You can pass an Optax optimizer directly to SVI (in place of numpyro.optim.Adam).
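
For example, something along these lines should work (a minimal sketch; the toy model, the data, and the particular schedule values are just placeholders):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import optax
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

# Toy model: infer the mean of some observations.
def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

guide = AutoNormal(model)

# Decay the learning rate exponentially over the course of the run.
schedule = optax.exponential_decay(
    init_value=1e-2, transition_steps=1000, decay_rate=0.1
)
optimizer = optax.adam(learning_rate=schedule)

# SVI accepts the Optax optimizer directly.
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
data = jnp.array([0.3, 1.2, 0.8])
svi_result = svi.run(jax.random.PRNGKey(0), 2000, data)
```

Any of the schedules in optax.schedules can be passed as the learning_rate in the same way.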


Am I correct that Optax's reduce-on-plateau scheduler is not yet compatible with NumPyro? When I try to use it, it errors because the previous loss value isn't being passed to the optimizer update.
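
For reference, here is roughly what reduce_on_plateau expects outside of SVI (a minimal sketch with a toy quadratic loss): its update step needs the current loss passed as the `value` keyword, which is the argument that SVI's internal update does not appear to forward.

```python
import jax
import jax.numpy as jnp
import optax

# Toy quadratic loss, just to illustrate the update signature.
def loss_fn(params):
    return jnp.sum(params ** 2)

params = jnp.array([1.0, -2.0])
opt = optax.chain(
    optax.adam(1e-2),
    optax.contrib.reduce_on_plateau(patience=5, factor=0.5),
)
opt_state = opt.init(params)

for _ in range(100):
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # reduce_on_plateau takes extra arguments: the current loss must be
    # passed as `value=`; a plain update(grads, state, params) call,
    # which is what SVI issues internally, raises an error.
    updates, opt_state = opt.update(grads, opt_state, params, value=loss)
    params = optax.apply_updates(params, updates)
```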