I’m having a bit of trouble figuring out how to use the ReduceLROnPlateau scheduler. Has anybody got this to work? The docs include an example that shows how to construct an ExponentialLR scheduler:
optimizer = torch.optim.SGD
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
so I constructed my scheduler like this:
scheduler = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam, 'optim_args': {'lr': 0.1}
})
Then I try to run SVI as follows:
svi = pyro.infer.SVI(model, guide, scheduler, pyro.infer.Trace_ELBO())
svi.step(inputs, labels)
but I get this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-106-6ecba3045e51> in <module>()
----> 1 svi.step(train_inputs, train_labels)
~/anaconda/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
80 # actually perform gradient steps
81 # torch.optim objects gets instantiated for any params that haven't been seen yet
---> 82 self.optim(params)
83
84 # zero gradients
~/anaconda/lib/python3.6/site-packages/pyro/optim/lr_scheduler.py in __call__(self, params, *args, **kwargs)
32 def __call__(self, params, *args, **kwargs):
33 kwargs['epoch'] = self.epoch
---> 34 super(PyroLRScheduler, self).__call__(params, *args, **kwargs)
35
36 def _get_optim(self, params):
~/anaconda/lib/python3.6/site-packages/pyro/optim/optim.py in __call__(self, params, *args, **kwargs)
56
57 # actually perform the step for the optim object
---> 58 self.optim_objs[p].step(*args, **kwargs)
59
60 def get_state(self):
TypeError: step() missing 1 required positional argument: 'metrics'
If I just use pyro.optim.Adam, svi.step works fine. I assume the missing “metrics” argument is the value the scheduler uses to decide when to reduce the LR (the scheduler doesn’t have a step method of its own, so I presume that check is wrapped up in svi.step), but I’m not sure how it should be passed or formatted.
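To make that guess concrete: the rule behind ReduceLROnPlateau (in mode='min') consumes one scalar per call, such as an epoch's loss, and lowers the learning rate only after that value has failed to improve for more than `patience` consecutive calls. The class below, PlateauLR, is a hypothetical, simplified re-implementation for illustration only; it drops the real scheduler's relative threshold, cooldown, and per-parameter-group handling.

```python
class PlateauLR:
    """Hypothetical, simplified plateau rule (mode='min' only).

    Cuts the LR by `factor` once the monitored metric has not improved
    for more than `patience` consecutive calls to step().
    """

    def __init__(self, lr, factor=0.1, patience=10, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")  # best metric seen so far
        self.num_bad = 0          # consecutive calls without improvement

    def step(self, metric):
        # The metric is a single number (e.g. this epoch's loss), not a dict.
        metric = float(metric)
        if metric < self.best:
            self.best = metric
            self.num_bad = 0
        else:
            self.num_bad += 1
        # Reduce only after the plateau outlasts the patience window.
        if self.num_bad > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad = 0
        return self.lr


sched = PlateauLR(lr=0.1, factor=0.5, patience=2)
for loss in [5.0, 4.0, 4.0, 4.0, 4.0]:
    lr = sched.step(loss)
print(lr)  # prints 0.05
```

The first two values improve, the next three do not, and the third non-improving call exceeds patience=2, so the learning rate is halved.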
I tried modifying both my model and my guide to accept an additional metrics argument, which they simply ignore (without that change, passing metrics to svi.step fails with guide() got an unexpected keyword argument 'metrics'). But even when I call svi.step(inputs, labels, metrics={'loss': 5}) or svi.step(inputs, labels, metrics=5), I still get the same TypeError saying metrics is missing.
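For comparison, a plain PyTorch loop (no Pyro) shows the signature the TypeError refers to: torch.optim.lr_scheduler.ReduceLROnPlateau.step expects the monitored value as one plain number, and it is called separately from the optimizer's gradient step. This is a sketch with a made-up toy parameter `w` and made-up monitored values; how Pyro's scheduler wrapper forwards this argument depends on the Pyro version and is not shown here.

```python
import torch

# Hypothetical toy setup standing in for real model parameters.
w = torch.nn.Parameter(torch.tensor(1.0))
optimizer = torch.optim.Adam([w], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)


def train_epoch():
    optimizer.zero_grad()
    loss = (w - 3.0) ** 2
    loss.backward()
    optimizer.step()  # the gradient step itself takes no metric
    return loss.item()


# Made-up monitored values that stop improving after the second epoch.
monitored = [5.0, 4.0, 4.0, 4.0, 4.0]
for value in monitored:
    train_epoch()
    scheduler.step(value)  # one plain float, passed to the scheduler directly

current_lr = optimizer.param_groups[0]["lr"]
print(current_lr)
```

In a Pyro loop the natural candidate for the monitored value would be the loss that svi.step returns, but whether a given Pyro version's scheduler wrapper exposes a step method to receive it is an assumption to check against that version's docs.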
How should I use this scheduler together with SVI? In particular, how do I structure the call to svi.step?