Change step_size in Adam optimizer during SVI training

I found this post, but it is not clear where I should import the function exponential_decay from, or how I should define it myself. The link to the earlier post by @fehiepsi suggested there is broken.

Problem solved. By looking at the optimization library I figured out that one can pass a callable that maps the iteration number to the step size. For example, for linear decay I defined

import jax.numpy as jnp

def linear_decay(iteration, init=0.001, end=0.0001, n_steps=10000):
    # Clamp the index so iterations beyond n_steps keep the final step size
    return jnp.linspace(init, end, n_steps)[jnp.minimum(iteration, n_steps - 1)]

and it seems to work.
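For completeness, here is a minimal sketch of passing such a callable as the step size to NumPyro's Adam and running SVI. The toy model, guide, and data below are placeholders I made up for illustration; only the schedule-as-callable part reflects the actual approach.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO

def model(data):
    # Toy model: infer the mean of a Normal
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

def guide(data):
    loc = numpyro.param("loc", 0.0)
    scale = numpyro.param("scale", 1.0, constraint=constraints.positive)
    numpyro.sample("mu", dist.Normal(loc, scale))

# Pass the schedule callable in place of a fixed learning rate
optimizer = numpyro.optim.Adam(step_size=linear_decay)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 10000, jnp.ones(5))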

For additional examples, I suggest searching the repo for different uses of optax, e.g. here; a sketch of one such schedule follows below.
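As an illustration only (assuming a NumPyro version that exposes numpyro.optim.optax_to_numpyro), a comparable linear schedule could be built directly with optax; the specific values here just mirror the linear_decay defaults above.

import optax
from numpyro.optim import optax_to_numpyro

# Linearly decay the learning rate from 1e-3 to 1e-4 over 10,000 steps
schedule = optax.linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=10_000)
optimizer = optax_to_numpyro(optax.adam(learning_rate=schedule))
# optimizer can then be passed to SVI in the same way as above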
