Problem solved. By looking at the optimization library (optax), I figured out that instead of a constant you can pass a callable that maps the iteration number to the step size. For example, for linear decay I defined:
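
    import jax.numpy as jnp

    def linear_decay(iteration, init=0.001, end=0.0001, n_steps=10000):
        # Linearly interpolate from `init` to `end` over `n_steps` points
        # and return the step size for the current iteration.
        return jnp.linspace(init, end, n_steps)[iteration]
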
and it seems to work.
For additional examples, I suggest searching the repo for other examples of optax usage, e.g. here.