How to use `ReduceLROnPlateau` scheduler

I’m having a bit of trouble figuring out how to use the ReduceLROnPlateau scheduler - has anybody got this to work? There’s an example in the docs here that shows how to construct an ExponentialLR scheduler:

optimizer = torch.optim.SGD
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})

so I constructed my scheduler like this:

scheduler = pyro.optim.ReduceLROnPlateau({
  'optimizer': torch.optim.Adam, 'optim_args': {'lr': 0.1}
})

Then I try to do SVI as follows:

svi = pyro.infer.SVI(model, guide, scheduler, pyro.infer.Trace_ELBO())
svi.step(inputs, labels)

but I get this error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-106-6ecba3045e51> in <module>()
----> 1 svi.step(train_inputs, train_labels)

~/anaconda/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     80         # actually perform gradient steps
     81         # torch.optim objects gets instantiated for any params that haven't been seen yet
---> 82         self.optim(params)
     83 
     84         # zero gradients

~/anaconda/lib/python3.6/site-packages/pyro/optim/lr_scheduler.py in __call__(self, params, *args, **kwargs)
     32     def __call__(self, params, *args, **kwargs):
     33         kwargs['epoch'] = self.epoch
---> 34         super(PyroLRScheduler, self).__call__(params, *args, **kwargs)
     35 
     36     def _get_optim(self, params):

~/anaconda/lib/python3.6/site-packages/pyro/optim/optim.py in __call__(self, params, *args, **kwargs)
     56 
     57             # actually perform the step for the optim object
---> 58             self.optim_objs[p].step(*args, **kwargs)
     59 
     60     def get_state(self):

TypeError: step() missing 1 required positional argument: 'metrics'

If I just use pyro.optim.Adam, svi.step works fine. I assume the missing “metrics” are what the scheduler uses to decide when to reduce the LR (the scheduler doesn’t expose a step method of its own, so I presume that’s wrapped up in svi.step), but I’m not sure how they should be formatted.

I tried modifying my model and guide to both accept an additional metrics argument which they just ignore (otherwise, passing anything extra to svi.step gives me an error that guide() got an unexpected keyword argument 'metrics'). But even if I call svi.step(inputs, labels, metrics={'loss': 5}) or svi.step(inputs, labels, metrics=5), I still get the same TypeError about metrics being missing.

How should I use this scheduler together with SVI? In particular, how do I structure the call to svi.step?

ReduceLROnPlateau is currently the only unsupported scheduler, because it requires the loss as an argument during training. Because of the way the SVI object is invoked, there is no easy way to thread the loss through without a lot of custom logic just for that scheduler, so it is not supported for now. All the other schedulers should work, though.
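For intuition, here is a rough pure-Python sketch of the plateau rule (this is my own illustration, not Pyro’s or torch’s actual implementation; the function name and defaults are made up): every update needs the current loss value, which is exactly the metrics argument the traceback complains about.

```python
def reduce_lr_on_plateau(losses, lr=0.1, factor=0.5, patience=2):
    """Sketch of the plateau rule: cut lr by `factor` once the loss
    has failed to improve for more than `patience` consecutive steps."""
    best = float("inf")
    bad_steps = 0
    lrs = []
    for loss in losses:          # <-- the loss must be fed in at every step
        if loss < best:
            best = loss
            bad_steps = 0
        else:
            bad_steps += 1
            if bad_steps > patience:
                lr *= factor
                bad_steps = 0
        lrs.append(lr)
    return lrs

# the LR drops only after the loss stalls for patience+1 steps
print(reduce_lr_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.8]))
# → [0.1, 0.1, 0.1, 0.1, 0.05, 0.05]
```

A scheduler like ExponentialLR only needs the step count, so SVI can drive it transparently; this one needs the loss threaded in from outside.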

Also note that you’ll want to use the dev version of Pyro, since there was a bug in the scheduler support that has been fixed but not yet released.


It appears, from looking at the code for pyro-ppl-1.8.5, that ReduceLROnPlateau should now be supported. However, I still see the error: TypeError: <lambda>() got an unexpected keyword argument 'lr'

I am attempting to create the optimizer and scheduler according to the documentation, like so…

pyro.clear_param_store()
learn_rate = 5e-4
optimizer = pyro.optim.Adam
scheduler = pyro.optim.ReduceLROnPlateau( {"optimizer":optimizer, "optim_args": {"lr": learn_rate}, "factor":0.9, "patience":7, "verbose":True})

svi = SVI(my_model, my_guide, scheduler, loss=pyro.infer.Trace_ELBO())

could anyone provide the correct implementation?

I believe the docs say that the optimizer should be a torch.optim one:

optimizer = torch.optim.Adam

though I’m not sure whether that’s the issue you’re running into; it’s also possible your options dictionary isn’t constructed correctly.


Oh, nice catch! Thank you. That does seem to fix the error.

Generally, looking at the tests is a good way to find correct usage.