I’m having a bit of trouble figuring out how to use the `ReduceLROnPlateau` scheduler. Has anybody got this to work? There’s an example in the docs that shows how to construct an `ExponentialLR` scheduler:
```python
import torch
import pyro.optim
optimizer = torch.optim.SGD
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
```
So I constructed my scheduler like this:
```python
scheduler = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam, 'optim_args': {'lr': 0.1}
})
```
Then I try to run SVI as follows:
```python
svi = pyro.infer.SVI(model, guide, scheduler, pyro.infer.Trace_ELBO())
svi.step(inputs, labels)
```
but I get this error:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-106-6ecba3045e51> in <module>()
----> 1 svi.step(train_inputs, train_labels)
~/anaconda/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
80 # actually perform gradient steps
81 # torch.optim objects gets instantiated for any params that haven't been seen yet
---> 82 self.optim(params)
83
84 # zero gradients
~/anaconda/lib/python3.6/site-packages/pyro/optim/lr_scheduler.py in __call__(self, params, *args, **kwargs)
32 def __call__(self, params, *args, **kwargs):
33 kwargs['epoch'] = self.epoch
---> 34 super(PyroLRScheduler, self).__call__(params, *args, **kwargs)
35
36 def _get_optim(self, params):
~/anaconda/lib/python3.6/site-packages/pyro/optim/optim.py in __call__(self, params, *args, **kwargs)
56
57 # actually perform the step for the optim object
---> 58 self.optim_objs[p].step(*args, **kwargs)
59
60 def get_state(self):
TypeError: step() missing 1 required positional argument: 'metrics'
```
If I just use `pyro.optim.Adam`, `svi.step` works fine. I assume the missing “metrics” are what the scheduler uses to decide when to reduce the LR. Since `scheduler` doesn’t have a `step` method, I presume that step is wrapped up in `svi.step`, but I’m not sure how the metrics should be formatted.
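For comparison, here is a minimal plain-PyTorch sketch (no Pyro wrapper) of how I understand the underlying `torch.optim.lr_scheduler.ReduceLROnPlateau` is called: `step()` takes the monitored value as a single scalar (a float or 0-dim tensor), not a dict, and calling it without one raises the same `TypeError`. The dummy parameter and the loss values are made up purely for illustration.

```python
import torch

# Dummy parameter so the optimizer has something to manage (illustration only).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.1)

# Plain PyTorch scheduler, not the Pyro wrapper: multiply the LR by `factor`
# once the monitored value has failed to improve for more than `patience` calls.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# Calling step() with only `epoch` (no metric) fails like the traceback above.
try:
    scheduler.step(epoch=None)
except TypeError as err:
    error_message = str(err)

# The metric is one scalar per call (e.g. an epoch's loss), not a dict.
made_up_epoch_losses = [5.0, 4.0, 4.0, 4.0, 4.0]
lr_history = []
for epoch_loss in made_up_epoch_losses:
    # (the epoch's actual optimizer.step() calls would go here)
    scheduler.step(epoch_loss)
    lr_history.append(optimizer.param_groups[0]["lr"])
```

With these values the LR stays at 0.1 for four calls and drops to roughly 0.01 on the fifth, after three non-improving values in a row.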
I tried modifying my `model` and `guide` to both accept an additional `metrics` argument, which they just ignore (otherwise, passing any extra keyword argument to `svi.step` gives me the error `guide() got an unexpected keyword argument 'metrics'`). But even if I call `svi.step(inputs, labels, metrics={'loss': 5})` or `svi.step(inputs, labels, metrics=5)`, I still get the same `TypeError` about `metrics` being missing.
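Reading the traceback, my guess at why the extra argument never arrives: `svi.step` forwards its arguments to `model` and `guide`, but then calls `self.optim(params)` with the parameters only, and the scheduler wrapper adds just `epoch` before calling the underlying `step()`. The toy classes below are hypothetical, heavily simplified stand-ins that mimic only that argument passing (they are not Pyro’s code), and they reproduce the same `TypeError`.

```python
# Hypothetical, heavily simplified stand-ins for the call chain in the traceback.
# These are NOT Pyro's classes; they only mimic how arguments are passed along.

received = []  # records what model/guide were called with


class ToyPlateauScheduler:
    """Same signature as torch's ReduceLROnPlateau.step: `metrics` is required."""

    def step(self, metrics, epoch=None):
        return metrics


class ToySchedulerWrapper:
    """Mimics PyroLRScheduler.__call__ -> PyroOptim.__call__ from the traceback."""

    def __init__(self):
        self.epoch = None
        self.optim_objs = {}

    def __call__(self, params, *args, **kwargs):
        kwargs["epoch"] = self.epoch  # only `epoch` is injected
        for p in params:
            if p not in self.optim_objs:
                self.optim_objs[p] = ToyPlateauScheduler()
            self.optim_objs[p].step(*args, **kwargs)


class ToySVI:
    def __init__(self, model, guide, optim):
        self.model, self.guide, self.optim = model, guide, optim

    def step(self, *args, **kwargs):
        # Everything passed to step() (including metrics=...) goes to model/guide.
        self.model(*args, **kwargs)
        self.guide(*args, **kwargs)
        params = {"some_param"}  # stand-in for the params captured from the trace
        self.optim(params)  # called with params only, like `self.optim(params)`
        return 0.0


def model(inputs, labels, metrics=None):
    received.append(("model", metrics))


def guide(inputs, labels, metrics=None):
    received.append(("guide", metrics))


svi = ToySVI(model, guide, ToySchedulerWrapper())
try:
    svi.step([1.0, 2.0], [0, 1], metrics=5)
except TypeError as err:
    error_message = str(err)
```

In this toy version `metrics=5` reaches `model` and `guide`, while the scheduler’s `step()` only ever sees `epoch`.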
How should I use this scheduler together with `SVI`? In particular, how do I structure the call to `svi.step`?