Properly using a Scheduler with an optimizer


#1

Just wondering if someone can provide a bit of advice on proper usage of a scheduler with an optimizer and SVI.

Here are a few snippets of relevant code I’ve tried, along with the accompanying error messages:

Snippet 1:

AdamArgs = { 'lr': 1e-3 }
optimizer = Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': AdamArgs, 'gamma': 0.996 })
seqVAE = AEModel.SeqVAE(141, 128, 512, 2, 128, 3)
svi = SVI(seqVAE.model, seqVAE.guide, scheduler, loss=Trace_ELBO())

Throws error:

TypeError: <lambda>() got an unexpected keyword argument 'lr'

Snippet 2:

AdamArgs = { 'lr': 1e-3 }
optimizer = Adam(AdamArgs)
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': AdamArgs, 'gamma': 0.996 })
seqVAE = AEModel.SeqVAE(141, 128, 512, 2, 128, 3)
svi = SVI(seqVAE.model, seqVAE.guide, scheduler, loss=Trace_ELBO())

Throws error:

TypeError: step() got an unexpected keyword argument 'lr'

What’s the proper usage here? I thought, in the case of my first example, that I was following the Pyro documentation…

Thanks in advance.

edit:
I should say that if I don’t use a scheduler and just use the optimizer on its own, everything runs and trains correctly without any errors, hence my not providing additional code.


#2

Ugh, seems like I usually figure out the answer to my question right after caving and posting to a forum about it.

You have to provide an ‘optimizer’ argument to pyro.optim.PyroLRScheduler(). This optimizer needs to be a subclass of torch.optim.Optimizer — the class itself, not an instance — and it seems that providing a Pyro-wrapped PyroOptim isn’t allowed.

This problem was fixed like so:

AdamArgs = { 'lr': 1e-3 }
optimizer = torch.optim.Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': AdamArgs, 'gamma': 0.996 })
seqVAE = AEModel.SeqVAE(141, 128, 512, 2, 128, 3)
svi = SVI(seqVAE.model, seqVAE.guide, scheduler, loss=Trace_ELBO())
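The key point is passing the optimizer *class* rather than an instance. As a sketch of why (hedged — this mirrors my understanding of what Pyro does internally, using plain torch so it stands alone): Pyro constructs a fresh torch optimizer from the class and optim_args, then hands that instance to the underlying torch LR scheduler. Passing an already-constructed (or Pyro-wrapped) optimizer breaks that construction step.

```python
import torch

# What the Pyro wrapper roughly does under the hood (assumption, simplified):
# instantiate the torch optimizer class with optim_args, then wrap it in the
# corresponding torch scheduler.
optim_args = {'lr': 1e-3}
optimizer_cls = torch.optim.Adam            # the class, NOT Adam(optim_args)

param = torch.nn.Parameter(torch.zeros(3))  # stand-in for a Pyro param
opt = optimizer_cls([param], **optim_args)  # works because we held the class
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.996)
```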

Follow-up question

In PyTorch, you need to call scheduler.step() to step the scheduler forward, which you can do once per epoch. Is the proper way to update the scheduler in Pyro to call scheduler.set_epoch(INT)?

Thanks!


#3

Is the proper way to update the scheduler in Pyro to call scheduler.set_epoch(INT)?

Yes.

Note that ReduceLROnPlateau is the only scheduler that isn’t supported.
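The per-epoch decay itself is just torch’s ExponentialLR semantics: each step multiplies the learning rate by gamma. A minimal self-contained sketch in plain torch (the Pyro-level call — set_epoch(epoch) here, or scheduler.step() in later Pyro releases — triggers the same update on the wrapped torch scheduler; the dummy opt.step() stands in for a real training step):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.996)

for epoch in range(10):
    # ... run your SVI/training steps over the batches here ...
    opt.step()    # placeholder training step (no gradients, so it's a no-op)
    sched.step()  # decay lr once per epoch: lr <- lr * gamma

print(opt.param_groups[0]['lr'])  # ≈ 1e-3 * 0.996 ** 10
```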