ClippedAdam Gradient Explosion

    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lens, annealing_factor=1.0):
        # number of time steps in this mini-batch
        T = mini_batch.size(1)

        # register all pytorch submodules with pyro
        pyro.module('dmm', self)

        # initial latent state, broadcast across the mini-batch
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))
        with pyro.plate("z_minibatch", len(mini_batch)):
            for t in pyro.markov(range(1, T + 1)):
                # parameters of the latent transition p(z_t | z_{t-1})
                z_loc, z_scale = self.transitor(z_prev)

                # scale the latent's log-prob by the KL annealing factor
                with pyro.poutine.scale(scale=annealing_factor):
                    z_dist = dist.Normal(
                        loc=z_loc, 
                        scale=z_scale
                    ).mask(
                        mini_batch_mask[:, t-1:t]
                    ).to_event(1)

                    z_t = pyro.sample(
                        name=f'z_{t}',
                        fn=z_dist
                    )

                # compute the emission -> probability of win/loss
                emission_t = self.emitter(z_t) 

                # sample the observation
                pyro.sample(
                    name=f'obs_x_{t}', 
                    fn=dist.Bernoulli(
                        probs=emission_t
                    ).mask(
                        mini_batch_mask[:, t-1:t]
                    ).to_event(1),
                    obs=mini_batch[:, t-1, 0]
                )

                z_prev = z_t

    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lens, annealing_factor=1.0):
        T = mini_batch.size(1)
        
        pyro.module('dmm', self)

        # prepare the initial rnn hidden state, run the rnn over the
        # time-reversed mini-batch, then undo the reversal and re-pad
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous().to(mini_batch.device)

        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        rnn_output = pad_and_reverse(rnn_output, mini_batch_seq_lens)
        rnn_output = self.layer_norm(rnn_output)

        # initial latent state for the variational distribution
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        with pyro.plate("z_minibatch", len(mini_batch)):
            for t in pyro.markov(range(1, T + 1)):
                # combine the previous latent with the rnn summary of the remaining
                # observations to parameterize q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t-1, :])

                with pyro.poutine.scale(scale=annealing_factor):
                    z_dist = dist.Normal(
                        loc=z_loc, 
                        scale=z_scale
                    ).mask(
                        mini_batch_mask[:, t-1:t]
                    ).to_event(1)

                    z_t = pyro.sample(
                        name=f'z_{t}',
                        fn=z_dist
                    )

                    z_prev = z_t

This code is the core of my model, a custom DMM (deep Markov model). I am using pyro.optim.ClippedAdam and have tried clip_norm = 0.00001, 1.0, 10, and a bunch of values in between, but the gradient norms stay large regardless and don't appear to change when I change clip_norm. Below is my model initialization and the config.yaml I am using.

device = torch_device('cpu')

# instantiate the dmm
dmm = DMM(
    input_dim=config['model_params']['input_dim'], 
    latent_dim=config['model_params']['latent_dim'],
    hidden_dim=config['model_params']['hidden_dim'], 
    emission_dim=config['model_params']['emission_dim'],
    transition_dim=config['model_params']['transition_dim'], 
    rnn_dim=config['model_params']['rnn_dim'],
    rnn_dropout_rate=config['model_params']['rnn_dropout_rate'],
).to(device)

adam = ClippedAdam(config['optimization_params'])

elbo = TraceEnum_ELBO()
dmm_guide = config_enumerate(
    dmm.guide,
    default="parallel",
    num_samples=config['evaluation_params']['tmc_num_samples'],
    expand=False
)
svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
# config.yaml
model_params:
  input_dim: 51
  latent_dim: 64
  hidden_dim: 16
  emission_dim: 1
  transition_dim: 128
  rnn_dim: 256
  rnn_dropout_rate: 0.0
training_params:
  num_epochs: 10
  mini_batch_size: 128
  annealing_epochs: 100
  minimum_annealing_factor: 0.2
  checkpoint_freq: 0
  test_frequency: 1
  cuda: True
optimization_params:
  lr: 0.00003
  betas: [0.96, 0.999]
  clip_norm: 1.0
  lrd: 0.99996
  weight_decay: 1.0
evaluation_params:
  tmc_num_samples: 10
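
For reference, the annealing_factor argument to model and guide is not computed anywhere in the code shown above. Below is a minimal sketch of how it is typically derived from annealing_epochs and minimum_annealing_factor, following the schedule in the Pyro DMM example; the training loop and helper name here are assumptions, not the author's actual code.

def get_annealing_factor(epoch, batch_idx, num_batches,
                         annealing_epochs, min_af):
    # ramp the KL weight linearly from min_af up to 1.0 over annealing_epochs
    if epoch < annealing_epochs:
        return min_af + (1.0 - min_af) * (
            float(batch_idx + epoch * num_batches + 1)
            / float(annealing_epochs * num_batches)
        )
    return 1.0

# hypothetical usage inside the training loop:
# af = get_annealing_factor(epoch, batch_idx, num_batches,
#                           config['training_params']['annealing_epochs'],
#                           config['training_params']['minimum_annealing_factor'])
# svi.step(mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lens, af)

Note that with num_epochs: 10 and annealing_epochs: 100, this schedule (assuming it is the one in use) never reaches a factor of 1.0 during training.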

what exactly are you saying? clipping happens at the level of the optimizer and doesn’t do in-place operations on .grad attributes
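
For reference, here is a minimal sketch of what clipping the .grad attributes in place would look like, using the ELBO's differentiable_loss (available on Trace_ELBO and TraceEnum_ELBO) with a plain PyTorch optimizer instead of SVI plus ClippedAdam. This is an alternative pattern, not what the code above does, and the mini-batch variable names and clip value are illustrative.

import torch

params = list(dmm.parameters())
optimizer = torch.optim.Adam(params, lr=3e-5)

# one training step with explicit, in-place gradient clipping
loss = elbo.differentiable_loss(dmm.model, dmm_guide,
                                mini_batch, mini_batch_reversed,
                                mini_batch_mask, mini_batch_seq_lens)
loss.backward()

# unlike ClippedAdam's optimizer-level clipping, this modifies p.grad in place,
# so any gradient norms logged afterwards reflect the clipped values
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

optimizer.step()
optimizer.zero_grad()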

@martinjankowiak
I see, I understand now why they seem large regardless of the level of clipping I do. I guess I would like advice on actually reducing the effective size of the gradients, as they continue to grow. I now understand that they will be clipped at the level of the optimizer, but if the actual .grad attributes continue to grow, that is worrisome too, no?
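
For reference, one way to watch the raw .grad magnitudes during training is to register hooks on the parameters in Pyro's param store, as the Pyro SVI tutorials do; a minimal sketch (the hooks see the gradients produced by backward, i.e. before the optimizer applies its clipping, and the mini-batch names are illustrative):

from collections import defaultdict

gradient_norms = defaultdict(list)

# take one step so that all parameters are registered in the param store
svi.step(mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lens)

# record the norm of each parameter's gradient on every backward pass
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(
        lambda g, name=name: gradient_norms[name].append(g.norm().item())
    )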

sounds troubling. do you see the same behavior if you run the tutorial?

I had not checked, but now that I have, I find that the gradient norm does increase over time and gets quite large (> 1e3). I should note that this is only being checked over the first 20 mini-batches.

1e3 doesn’t necessarily sound crazy scary to me

Unfortunately, that is in the tutorial, where the dataset is significantly smaller than in my current example. It is hard to compare, since not only is the amount of data I am using larger, but so are the learning objective and the size of the hidden dimensions. For reference, in my implementation:

Mini-batch 1 gradient norms:
min: ~40
max: ~17,000

Mini-batch 300 gradient norms:
min: ~12,000
max: ~5,000,000

Although training likely survives a couple of epochs because of the optimizer-level clipping, it seems to inevitably lead to a ValueError where my weight matrix is full of NaN values.
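
For reference, a minimal sketch of a per-step sanity check that catches the blow-up earlier than the eventual ValueError; the helper and loop variables here are illustrative:

import torch

def check_params_finite(step_idx):
    # stop as soon as any registered parameter contains nan or inf
    for name, value in pyro.get_param_store().named_parameters():
        if not torch.isfinite(value).all():
            raise RuntimeError(f"non-finite values in '{name}' at step {step_idx}")

# hypothetical usage inside the training loop:
# loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lens)
# check_params_finite(step_idx)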

rnns are notoriously tricky to train, especially for long sequence lengths. you may need to detach rnn intermediates to get more stable training
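
For reference, one common reading of "detach rnn intermediates" is truncated backpropagation through time: run the rnn over the sequence in chunks and detach the hidden state between chunks, so gradients never span the full sequence. A minimal sketch, assuming a batch_first rnn whose input is a plain padded tensor (not a PackedSequence) and which returns a single hidden-state tensor (nn.RNN/nn.GRU style); this is not necessarily the specific fix referenced in the linked post:

import torch

def run_rnn_truncated(rnn, inputs, h_0, chunk_size=20):
    # inputs: (batch, T, input_dim), h_0: (num_layers, batch, rnn_dim)
    outputs, h = [], h_0
    for start in range(0, inputs.size(1), chunk_size):
        out, h = rnn(inputs[:, start:start + chunk_size, :], h)
        # cut the gradient path here so backprop stays within each chunk
        h = h.detach()
        outputs.append(out)
    return torch.cat(outputs, dim=1), h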

I really appreciate your help so far. Just to see what would happen, I wrapped the rnn in torch.no_grad() (I know this is not what you were referring to with the link you sent), but I still find that the gradient explodes. Is it possible there is a fundamental problem with the emission? With or without the rnn, it has the highest grad_norm.

anything is possible. unfortunately i don’t have the time to debug your neural networks for you… but i suspect doing various kinds of gradient surgery like you’re doing will help you isolate the issue, which is probably a deep learning issue and not a pyro issue as such


In any case, I appreciate the time you’ve taken to help.