I found this post, but it is not clear where I should import the function `exponential_decay` from, or how I should define it. The suggested link to the previous post by @fehiepsi is broken.

Problem solved. By looking at the optimization library, I figured out that one can pass a callable that maps the iteration number to the `step_size`. For example, for linear decay I defined

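```
import jax.numpy as jnp

def linear_decay(iteration, init=0.001, end=0.0001, n_steps=10000):
    # Step size decreases linearly from `init` to `end`; valid for 0 <= iteration < n_steps.
    return jnp.linspace(init, end, n_steps)[iteration]
```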

and it seems to work.
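The same pattern should work for exponential decay. Below is a minimal sketch: `my_exponential_decay` and its default values are hypothetical, and it is compared against `jax.example_libraries.optimizers.exponential_decay`, a JAX helper that returns a schedule computing `step_size * decay_rate ** (i / decay_steps)`. I can't confirm this is the function the original post referred to, but it has the same name and behaves as a callable schedule.

```
from jax.example_libraries import optimizers

# Hypothetical hand-written schedule: the step size is multiplied by
# `decay_rate` every `decay_steps` iterations (smoothly, not in jumps).
def my_exponential_decay(iteration, init=0.001, decay_rate=0.1, decay_steps=10000):
    return init * decay_rate ** (iteration / decay_steps)

# JAX's helper returns an equivalent callable schedule.
builtin_schedule = optimizers.exponential_decay(
    step_size=0.001, decay_steps=10000, decay_rate=0.1
)

for i in [0, 5000, 10000]:
    print(i, float(my_exponential_decay(i)), float(builtin_schedule(i)))
```

Either callable should then be usable as the step size, e.g. `numpyro.optim.Adam(step_size=builtin_schedule)`, since NumPyro's `Adam` forwards its arguments to JAX's `optimizers.adam`, which accepts a schedule in place of a constant.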

For additional examples, I suggest searching the repo for examples of optax usage, e.g. here.
