Learning rate changes with ReduceLROnPlateau

Hi all,

I have a question regarding the behaviour of the ReduceLROnPlateau scheduler in
combination with the Adam optimiser. I am training a normalizing flow and testing
different training methods. I want the training to be as efficient as practically
possible, which means that I want it to decrease the learning rate as the training
loss gets lower, to avoid overshooting correct solutions.

My goal is thus to have a scheduler reduce the learning rate by some factor whenever
the mean loss of the current epoch improves on the previous best by more than some
other factor. Otherwise the optimiser should just keep the same learning rate, at
least for several epochs, to give the training a chance.
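
In plain Python the behaviour I am after looks roughly like this (just a sketch to
make the intent concrete; desired_lr_update, reduction_factor and improvement_factor
are illustrative names, not code I actually use):

def desired_lr_update(mean_loss, best_mean_loss, lr,
                      reduction_factor=0.1, improvement_factor=0.5):
    # The first epoch only establishes the baseline.
    if best_mean_loss is None:
        return lr, mean_loss
    # Reduce the learning rate only when this epoch beats the best by the given factor.
    if mean_loss < improvement_factor * best_mean_loss:
        return lr * reduction_factor, mean_loss
    # Otherwise keep both the learning rate and the best loss unchanged.
    return lr, best_mean_loss

lr, best = 5e-4, None
for mean_loss in [380.0, 330.0, 150.0, 140.0]:
    lr, best = desired_lr_update(mean_loss, best, lr)
    print(f"mean_loss={mean_loss:6.1f}  lr={lr:g}  best={best:g}")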

The obvious choice for this seemed to be ReduceLROnPlateau, but I find some of its
behaviour slightly curious. If I set the patience to a high number (so that exceeding
the patience should not be what triggers the learning rate change), then in the cases
where the change should trigger because the mean loss falls below the threshold, I do
not see the learning rate change (although num_bad_epochs does seem to indicate that
something triggers, because it is reset). See below for an MWE and the output.

That makes me question whether I am looking up the learning rate correctly.

When I set the patience to be very low (basically 1), so that exceeding the patience
triggers the change in learning rate, I do see the learning rate change…

If it is the case that I am looking up the learning rate correctly, I am wondering
whether something is wrong with ReduceLROnPlateau.
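
To check whether this comes from the Pyro wrapper or from the scheduler itself, the
same settings can be fed to a bare torch ReduceLROnPlateau; a minimal sketch with a
dummy parameter and made-up losses (an isolation test, not part of the MWE below):

import torch

# Dummy parameter so the optimiser has something to attach a learning rate to.
param = torch.nn.Parameter(torch.zeros(1))
adam = torch.optim.Adam([param], lr=5e-4)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    adam, factor=0.1, patience=10, threshold=0.5,
)

# Made-up mean losses that keep improving, similar to the large-patience run.
for epoch, mean_loss in enumerate([388.5, 337.5, 268.7, 177.2, 71.4, 9.7]):
    plateau.step(mean_loss)
    print(epoch, mean_loss, adam.param_groups[0]["lr"], plateau.num_bad_epochs)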

Am I missing something?

The code below is representative of how I have implemented the training and scheduling.

from functools import partial

import numpy as np
import pyro
import pyro.distributions as dist
import torch

#
from pyro import optim
from pyro.distributions.transforms import block_autoregressive, iterated
from pyro.infer import SVI
from pyro.infer.autoguide import AutoNormalizingFlow
from pyro.infer.trace_elbo import Trace_ELBO


def newmodel(mvn_distribution_specifications):
    """
    Wrapper function that builds and returns the model function
    """

    mean = mvn_distribution_specifications["mean"].astype(np.float32)
    covariance = mvn_distribution_specifications["covariance"].astype(np.float32)

    #
    mvn_distribution = dist.MultivariateNormal(
        loc=torch.tensor(mean),
        covariance_matrix=torch.tensor(covariance),
    )
    
    def model():
        pyro.sample(
            "model",
            mvn_distribution,
        )

    return model


def mvn_model(mvn_distribution_specifications):
    """
    Model function for the mcmc sampler
    """

    #
    return newmodel(mvn_distribution_specifications)

############
# Test specifications
mvn_distribution_specifications = {
    "mean": np.array(
        [60.0, 20.0],
        # dtype=np.float64
    ),
    "covariance": np.array(
        [[5.0, 0.0], [0.0, 5.0]],
        # dtype=np.float64
    ),
}

#
epochs = 10
initial_learning_rate = 5e-4
num_flows_map = 3
steps_per_epoch = 1000

#
model = mvn_model(mvn_distribution_specifications=mvn_distribution_specifications)

# set up first parts of map building
guide = AutoNormalizingFlow(
    model, partial(iterated, num_flows_map, block_autoregressive)
)

# Pyro's ReduceLROnPlateau wrapper, which constructs Adam optimisers with the given args
scheduler = optim.ReduceLROnPlateau(
    {
        "optimizer": torch.optim.Adam,
        "optim_args": {"lr": initial_learning_rate},
        "factor": 0.1,
        "patience": 10,
        "threshold": 0.5,
    }
)

# Stochastic variational inference using the scheduled optimiser
svi = SVI(model, guide, scheduler, Trace_ELBO())

#######
# Loop over epochs
for epoch_i in range(epochs):

    ######
    # Run the SVI steps for this epoch and collect the per-step losses
    epoch_loss_array = []
    for local_training_i in range(steps_per_epoch):
        total_training_i = local_training_i + (epoch_i * steps_per_epoch)

        #############
        # Calculate loss
        loss = svi.step()

        #############
        #
        epoch_loss_array.append(loss)

    ##############
    #
    print("After training epoch {}".format(epoch_i))

    # The mean loss over the epoch is the metric handed to the scheduler
    mean_loss_array = np.mean(np.array(epoch_loss_array))
    print("current epoch mean_loss_array: ", mean_loss_array)
    scheduler.step(mean_loss_array)

    # Inspect one of the torch scheduler objects held by the Pyro wrapper,
    # together with the Adam optimiser it wraps
    torch_scheduler = list(scheduler.optim_objs.values())[0]
    print(torch_scheduler.state_dict())
    print(torch_scheduler.optimizer)
    print("==========")

Output for large patience, where the drop in the loss should trigger updating the
learning rate. Note: num_bad_epochs seems to indicate some change, but _last_lr
doesn’t change.

epoch 0
current epoch mean_loss_array:  388.533140048258
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 388.533140048258, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 1, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 1
current epoch mean_loss_array:  337.4650459958528
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 388.533140048258, 'num_bad_epochs': 1, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 2, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 2
current epoch mean_loss_array:  268.6661277315496
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 388.533140048258, 'num_bad_epochs': 2, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 3, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 3
current epoch mean_loss_array:  177.23352393411028
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 177.23352393411028, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 4, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 4
current epoch mean_loss_array:  71.40807188195642
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 71.40807188195642, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 5, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 5
current epoch mean_loss_array:  9.735887116171082
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 9.735887116171082, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 6, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
epoch 6
current epoch mean_loss_array:  0.9936419904291406
{'factor': 0.1, 'min_lrs': [0], 'patience': 10, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 0.9936419904291406, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 7, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)

Output for small patience, where exceeding the patience always forces an update of
the learning rate. These results were generated by setting patience to 0 in the code
above. Note: the learning rate does change here.

After training epoch 0
current epoch mean_loss_array:  375.436277274251
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 1, '_last_lr': [0.0005]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.0005
    maximize: False
    weight_decay: 0
)
==========
After training epoch 1
current epoch mean_loss_array:  328.9502306429148
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 2, '_last_lr': [5e-05]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5e-05
    maximize: False
    weight_decay: 0
)
==========
After training epoch 2
current epoch mean_loss_array:  298.6021881240606
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 3, '_last_lr': [5e-06]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5e-06
    maximize: False
    weight_decay: 0
)
==========
After training epoch 3
current epoch mean_loss_array:  295.58652116286754
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 4, '_last_lr': [5.000000000000001e-07]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000001e-07
    maximize: False
    weight_decay: 0
)
==========
After training epoch 4
current epoch mean_loss_array:  295.3207841010094
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 5, '_last_lr': [5.000000000000001e-08]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000001e-08
    maximize: False
    weight_decay: 0
)
==========
After training epoch 5
current epoch mean_loss_array:  295.24778697776793
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 6, '_last_lr': [5.000000000000002e-09]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000002e-09
    maximize: False
    weight_decay: 0
)
==========
After training epoch 6
current epoch mean_loss_array:  295.28480722212794
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 7, '_last_lr': [5.000000000000002e-09]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000002e-09
    maximize: False
    weight_decay: 0
)
==========
After training epoch 7
current epoch mean_loss_array:  295.27312865316867
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 8, '_last_lr': [5.000000000000002e-09]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000002e-09
    maximize: False
    weight_decay: 0
)
==========
After training epoch 8
current epoch mean_loss_array:  295.25689352357386
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 9, '_last_lr': [5.000000000000002e-09]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000002e-09
    maximize: False
    weight_decay: 0
)
==========
After training epoch 9
current epoch mean_loss_array:  295.288202801466
{'factor': 0.1, 'min_lrs': [0], 'patience': 0, 'verbose': False, 'cooldown': 0, 'cooldown_counter': 0, 'mode': 'min', 'threshold': 0.5, 'threshold_mode': 'rel', 'best': 375.436277274251, 'num_bad_epochs': 0, 'mode_worse': inf, 'eps': 1e-08, 'last_epoch': 10, '_last_lr': [5.000000000000002e-09]}
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 5.000000000000002e-09
    maximize: False
    weight_decay: 0
)
==========

Is this perhaps more of a PyTorch-related question? I would still like to resolve this issue.

I don’t know what you’re seeing, but you might try circumventing SVI and using differentiable_loss instead to see what kind of behavior you get:

https://pyro.ai/examples/custom_objectives.html#A-Lower-Level-Pattern
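
Roughly, the lower-level pattern from that page looks like the sketch below when
adapted to your MWE (untested; it reuses model, guide, initial_learning_rate, epochs
and steps_per_epoch from your code and hands the learning-rate bookkeeping to a plain
torch ReduceLROnPlateau):

import torch
import pyro
from pyro.infer import Trace_ELBO

loss_fn = Trace_ELBO().differentiable_loss

# One pass through model/guide so the guide parameters exist, then collect them.
with pyro.poutine.trace(param_only=True) as param_capture:
    loss_fn(model, guide)
params = set(
    site["value"].unconstrained()
    for site in param_capture.trace.nodes.values()
)

# Plain torch optimiser and scheduler, so the learning rate is read and updated
# directly on the torch side.
optimizer = torch.optim.Adam(params, lr=initial_learning_rate)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, threshold=0.5
)

for epoch_i in range(epochs):
    epoch_losses = []
    for _ in range(steps_per_epoch):
        loss = loss_fn(model, guide)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())
    plateau.step(sum(epoch_losses) / len(epoch_losses))
    print(epoch_i, "lr:", optimizer.param_groups[0]["lr"])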