Bug in LR_scheduler?

I was trying to implement an lr_range_test, and in the process I believe I found a bug in the PyroLRScheduler.

I have created a simple VAE and trained it on FashionMNIST in both Pyro and PyTorch. In all cases I used RMSprop as the optimizer and StepLR as the scheduler, which multiplies the LR by a factor gamma at every epoch.

I have done the following checks:

  1. for LR=1E-4 and gamma=1.0 (i.e. the LR is constant), both the Pyro and PyTorch implementations work great.

  2. for LR=1E-6 and gamma=1.5 (i.e. the LR increases exponentially), the PyTorch implementation works for 22 epochs, until the LR becomes too large (about 1E-2) and the loss becomes NaN. This is the expected behavior.

  3. with the same setup, i.e. LR=1E-6 and gamma=1.5, the Pyro implementation produces NaN from the start (not even a single epoch runs successfully).
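As a quick sanity check on case 2: with step_size=1, StepLR sets the learning rate at epoch n to LR0 * gamma**n, so the growth is geometric. The plain-Python sketch below (no Pyro or PyTorch required; the variable names are illustrative) finds the first epoch at which 1E-6 * 1.5**n exceeds 1E-2, both by direct search and in closed form.

```python
import math

lr0, gamma, lr_max = 1e-6, 1.5, 1e-2

# StepLR with step_size=1: the LR at epoch n is lr0 * gamma**n.
first_epoch = next(n for n in range(1000) if lr0 * gamma**n > lr_max)

# Closed form: smallest integer n with n > log(lr_max / lr0) / log(gamma).
closed_form = math.floor(math.log(lr_max / lr0) / math.log(gamma)) + 1

print(first_epoch, closed_form)
print("LR at epoch 22: %.2e" % (lr0 * gamma**22))
```

With these numbers the LR is roughly 7.5E-3 at epoch 22 and first exceeds 1E-2 at epoch 23, which is consistent with the PyTorch run diverging after about 22 epochs.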

I have done other tests (changing the initial LR and gamma), and the conclusion is that, unless gamma=1.0, the implementation with the Pyro scheduler produces NaN before the first epoch is even finished (the end of the first epoch is the earliest point at which the scheduler should change the learning rate).

I do not believe this is a mistake in my code, since the code runs perfectly when gamma=1.0. The problem must be related to the scheduler.

Below are links to the PyTorch vs. Pyro comparison for gamma=1.0 and gamma=1.5.

gamma=1.5
gamma=1.0

Should I open a bug report?

hi @luca, what version of Pyro are you running? Also, what is the implementation of train()? Because you use a lot of import * statements, it's hard to figure out what's happening behind the scenes. Are you stepping the scheduler as well?

Hello,
thanks for checking into this.

I am using Pyro 0.3.3.
The implementation of train() is copied below.

def train(svi, loader, use_cuda=False, verbose=False):
    epoch_loss = 0.
    for x, _ in loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()

        loss = svi.step(x)
        if verbose:
            print("loss=%.5f" % loss)
        epoch_loss += loss

    # average loss per example over the whole dataset
    return epoch_loss / len(loader.dataset)

You can check out my entire code in this GitHub repository.
In the Pyro implementation I am not stepping the scheduler, since the post “LR scheduler usage” says this is done automatically.

please use the latest version, 0.3.4, and note that you'll have to call scheduler.step() yourself, after svi.step()

I managed to make the scheduler work with version 0.3.4.
I am copying below the code for the lr_range_finder, in case it is of interest to other people:

import matplotlib.pyplot as plt
import pyro
import torch
from pyro.infer import SVI, Trace_ELBO

# vae, loader, N_epoch and use_cuda are defined elsewhere (see the repository)


def one_epoch_train(svi, loader, use_cuda=False):
    epoch_loss = 0.
    for x, _ in loader:
        if use_cuda:
            x = x.cuda()
        loss = svi.step(x)
        epoch_loss += loss
    return epoch_loss / len(loader.dataset)

optimizer_args = {'lr': 1E-6}
scheduler_args = {'optimizer': torch.optim.RMSprop,
                  'step_size': 1, 'gamma': 1.5,
                  'optim_args': optimizer_args}
pyro_scheduler = pyro.optim.StepLR(scheduler_args)

svi = SVI(vae.model, vae.guide, pyro_scheduler, loss=Trace_ELBO())

# Loop to compute lr_range_finder
hist_loss, hist_lr = [], []
for epoch in range(N_epoch):
    svi.optim.step(epoch=epoch)  # set the epoch in the LR scheduler
    loss_curr = one_epoch_train(svi, loader, use_cuda=use_cuda)
    # get the current LR from the first (parameter -> scheduler) entry
    lr_curr = next(iter(svi.optim.optim_objs.values())).get_lr()[0]
    hist_loss.append(loss_curr)
    hist_lr.append(lr_curr)

plt.plot(hist_lr, hist_loss)

Note that:

  1. svi.step() is called once per minibatch
  2. scheduler.step() is called by svi.optim.step(epoch=epoch) once per epoch

there are some bugs in the PyTorch scheduler that are being fixed for the next release. Note also that if you are using torch 1.1, the order of scheduler and optimizer stepping is flipped, so the recommended way is:

for epoch in range(epochs):
    for minibatch in data:
        svi.step(minibatch)
        scheduler.step(epoch=epoch)

or just step it once per epoch, since (depending on which scheduler you're using) that is the only time the learning rate changes:

for epoch in range(epochs): 
    for minibatch in data:
        svi.step(minibatch)
    scheduler.step()
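The difference between these loops is where the scheduler's epoch counter advances. Below is a toy model in plain Python: the ToyStepLR class and the lr_history helper are hypothetical, not the PyTorch or Pyro classes, and they only mimic StepLR's rule lr = base_lr * gamma ** (last_epoch // step_size). Passing an explicit epoch overwrites the counter, so repeating step(epoch=epoch) on every minibatch is harmless, whereas calling step() with no argument on every minibatch advances the schedule once per batch.

```python
class ToyStepLR:
    """Hypothetical stand-in for StepLR: models only the LR rule."""

    def __init__(self, base_lr, gamma, step_size=1):
        self.base_lr = base_lr
        self.gamma = gamma
        self.step_size = step_size
        self.last_epoch = 0

    def step(self, epoch=None):
        # No argument: advance the counter by one.
        # Explicit epoch: overwrite the counter (repeat calls change nothing).
        if epoch is None:
            self.last_epoch += 1
        else:
            self.last_epoch = epoch

    def get_lr(self):
        return self.base_lr * self.gamma ** (self.last_epoch // self.step_size)


def lr_history(pattern, n_epochs=3, batches_per_epoch=100):
    """Return the LR that each (hypothetical) svi.step() call would see."""
    sched = ToyStepLR(base_lr=1e-6, gamma=1.5)
    seen = []
    for epoch in range(n_epochs):
        for _ in range(batches_per_epoch):
            seen.append(sched.get_lr())  # LR used by svi.step(minibatch)
            if pattern == "epoch_arg_per_batch":
                sched.step(epoch=epoch)
            elif pattern == "no_arg_per_batch":
                sched.step()
        if pattern == "no_arg_per_epoch":
            sched.step()
    return seen


for pattern in ("epoch_arg_per_batch", "no_arg_per_epoch", "no_arg_per_batch"):
    seen = lr_history(pattern)
    # LR in the middle of epochs 0, 1 and 2 (100 minibatches per epoch)
    print(pattern, ["%.2e" % seen[i] for i in (50, 150, 250)])
```

In this toy, the first two patterns give the same LR except on the first minibatch of each epoch after the first (the explicit-epoch version only updates the counter after that minibatch), while the third multiplies the LR by 1.5 on every minibatch. If the scheduler was being advanced on every svi.step() call (an assumption about the automatic stepping, not something confirmed in this thread), that would also be consistent with the NaN within the first epoch for any gamma other than 1.0.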