I found this post, but it is not clear where I should import the function `exponential_decay` from, or how I should define it. The link it gives to a previous post by @fehiepsi is broken.
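For reference, the `exponential_decay` in question is most likely the one in `jax.example_libraries.optimizers` (called `jax.experimental.optimizers` in older JAX releases), which can be imported directly. A minimal sketch of an equivalent schedule factory, assuming that library's formula `step_size * decay_rate ** (i / decay_steps)`:

```python
def exponential_decay(step_size, decay_steps, decay_rate):
    """Return a schedule mapping an iteration number to a step size.

    Follows the formula of jax.example_libraries.optimizers.exponential_decay:
    step_size * decay_rate ** (i / decay_steps).
    """
    def schedule(i):
        # Only arithmetic operators are used, so `i` may be a Python int
        # or a JAX integer array passed in by the optimizer.
        return step_size * decay_rate ** (i / decay_steps)
    return schedule


schedule = exponential_decay(1e-3, decay_steps=1000, decay_rate=0.5)
print(schedule(0))     # 0.001
print(schedule(1000))  # 0.0005
```

The returned callable should be usable wherever a step size is expected, e.g. `numpyro.optim.Adam(schedule)`, assuming NumPyro's `numpyro.optim` wrappers, which forward the step size to the JAX optimizers.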
Problem solved. By looking at the optimization library, I figured out that one can pass a callable that maps the iteration number to the `step_size`. For example, for linear decay I defined
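```python
import jax.numpy as jnp
def linear_decay(iteration, init=0.001, end=0.0001, n_steps=10000):
    # look up this iteration's step size on a linear grid from init to end
    return jnp.linspace(init, end, n_steps)[iteration]
```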
and it seems to work.
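A variant of the same idea, sketched under the assumption that a closed-form schedule is acceptable: the hypothetical `linear_decay_closed_form` computes the interpolation directly instead of building a 10000-element array on every call, and clips the fraction so the step size stays at `end` once `n_steps` is reached.

```python
import jax.numpy as jnp


def linear_decay_closed_form(iteration, init=0.001, end=0.0001, n_steps=10000):
    """Linearly interpolate from `init` to `end` over `n_steps` iterations.

    Matches jnp.linspace(init, end, n_steps)[iteration] up to floating-point
    rounding for 0 <= iteration < n_steps, and holds the value at `end`
    for any later iteration.
    """
    # Fraction of the schedule completed, clipped to [0, 1].
    frac = jnp.clip(iteration / (n_steps - 1), 0.0, 1.0)
    return init + (end - init) * frac
```

Because it uses only `jnp` operations, it should also work when the optimizer passes the iteration count as a traced JAX integer inside `jit`.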
For additional examples, I suggest searching the repo for other examples of optax usage, e.g. here.