SVI: nans from guide when Trace_ELBO drops

During training, I’m getting nans when the loss drops.

lr=1e-6
svi = pyro.infer.SVI(model=model, 
                     guide=guide, 
                     optim=pyro.optim.Adam({"lr": lr}), 
                     loss=pyro.infer.Trace_ELBO())

This happens consistently, even when the step size is low enough that training is quite stable.

Attached is a plot. The prediction is quite good after the loss drops, so “things are working”.

  1. What causes the nans?
  2. How can they be avoided?
  3. How can I save the guide (neural network) params so that I can “restart” training if I encounter nans?

Here is a picture of the loss dropping. There’s an averaging window of 1000 steps with pandas.Series(loss).rolling(1000), so it’s much more stochastic than it looks.

This is the error

/usr/local/lib/python3.7/dist-packages/pyro/poutine/trace_struct.py:286: UserWarning: Encountered NaN: log_prob_sum at site 'quaternions'
  site["log_prob_sum"], "log_prob_sum at site '{}'".format(name)
/usr/local/lib/python3.7/dist-packages/pyro/infer/trace_elbo.py:158: UserWarning: Encountered NaN: loss
  warn_if_nan(loss, "loss")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

19 frames
ValueError: Expected parameter concentration (Tensor of shape (500, 2, 4)) of distribution ProjectedNormal(concentration: torch.Size([500, 2, 4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        ...,

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]]], device='cuda:0', grad_fn=<MulBackward0>)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<decorator-gen-53> in time(self, line, cell, local_ns)

<timed exec> in <module>()

/usr/local/lib/python3.7/dist-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     54                 if not valid.all():
     55                     raise ValueError(
---> 56                         f"Expected parameter {param} "
     57                         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                         f"of distribution {repr(self)} "

ValueError: Expected parameter concentration (Tensor of shape (500, 2, 4)) of distribution ProjectedNormal(concentration: torch.Size([500, 2, 4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        ...,

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]]], device='cuda:0', grad_fn=<MulBackward0>)
                   Trace Shapes:                 
                    Param Sites:                 
   net_rot$$$cnn_layers.0.weight  32   1   3    3
     net_rot$$$cnn_layers.0.bias               32
   net_rot$$$cnn_layers.2.weight  32  32   3    3
     net_rot$$$cnn_layers.2.bias               32
   net_rot$$$cnn_layers.5.weight  64  32   3    3
     net_rot$$$cnn_layers.5.bias               64
   net_rot$$$cnn_layers.7.weight  64  64   3    3
     net_rot$$$cnn_layers.7.bias               64
  net_rot$$$cnn_layers.10.weight 128  64   3    3
    net_rot$$$cnn_layers.10.bias              128
  net_rot$$$cnn_layers.12.weight 128 128   3    3
    net_rot$$$cnn_layers.12.bias              128
net_rot$$$linear_layers.0.weight         512 2048
  net_rot$$$linear_layers.0.bias              512
net_rot$$$linear_layers.2.weight         512  512
  net_rot$$$linear_layers.2.bias              512
net_rot$$$linear_layers.4.weight          12  512
  net_rot$$$linear_layers.4.bias               12
                   Sample Sites:                 
                 mini_batch dist                |
                           value         500    |
                        dfs dist         500    |
                           value         500    |

Is there any documentation on model/guide persistence? I’d like to save the model every so many iterations in case this happens as a restart.

you might want to e.g. clip the input parameters to ProjectedNormal so that they don’t hit dangerous extremes

Hmm, how exactly should I do that?

The guide predicts a normalized 4-vector and a concentration, which then go into the ProjectedNormal. I don’t see how I could clip the 4-vector, since all of its values are valid. The prior is uniform, or nearly uniform. Should I clip it to be near the prior, by somehow taking the average on the 4 sphere between them?

As for the concentration, it’s positive. I’m exp-ing a real-valued output, so perhaps it’s getting too small or too big… Maybe what is happening is that the concentration is getting too large… so I should clip it at some max value such that the distribution only gets so tight.

But thanks @martinjankowiak, this gives me some things to check :slight_smile:

Also, I lowered the learning rate from 1e-6 to 1e-7 (Adam), when the loss was dropping, but it still dropped and returned nans.

I changed it around 30000 where the inflection happens.

well your error message suggests the concentration is the problem. you might do e.g.

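# pyro.param takes an initial value second; 0.0 here is an arbitrary placeholder
from pyro.distributions import constraints
log_conc = pyro.param("log_conc", torch.tensor(0.0), constraint=constraints.interval(-5.0, 5.0))
conc = log_conc.exp()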

you might also try using 64-bit precision
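
To see why precision may matter here, the following sketch (an illustration added for this write-up, not code from the thread) evaluates the factor 1 + erf(t/sqrt(2)) from the closed form quoted later in the thread, once in float32 and once in float64. The value t = -6 is an arbitrary choice, and the reference value uses the identity 1 + erf(x) = erfc(-x).

```python
import math

import torch

t = -6.0  # arbitrary illustrative value, not taken from the thread
arg = t / math.sqrt(2.0)
reference = math.erfc(-arg)  # 1 + erf(arg) == erfc(-arg), computed without cancellation

# The same expression evaluated naively in single and double precision.
tail32 = (1.0 + torch.erf(torch.tensor(arg, dtype=torch.float32))).item()
tail64 = (1.0 + torch.erf(torch.tensor(arg, dtype=torch.float64))).item()

rel_err32 = abs(tail32 - reference) / reference
rel_err64 = abs(tail64 - reference) / reference
print(f"float32: {tail32:.3e}  relative error {rel_err32:.1e}")
print(f"float64: {tail64:.3e}  relative error {rel_err64:.1e}")

# One way to run a whole model in double precision (call before building tensors):
# torch.set_default_dtype(torch.float64)
```

In float32 the spacing of representable numbers near 1 is far larger than the true tail value, so the single-precision result carries essentially no information. Double precision only pushes the problem to more negative t; it does not remove the cancellation.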

log_conc is coming out of a neural network. I am doing pyro.module and registering all params of the neural network.

Should I also do pyro.param on the log_conc, or just clip it?

It’s not a parameter, but one of the distributional parameters for a latent-space variable z. It characterizes the posterior q. I use it to build the distribution from which I sample z. See SVI Part III: ELBO Gradient Estimators — Pyro Tutorials 1.8.1 documentation

oh if it comes out of a neural network then you can do e.g.

log_conc = log_conc.clamp(min=-5, max=5)

no need to involve param
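
A minimal sketch of the clamp applied to a network output before exponentiating. The tanh-based variant is not from the thread; it is an alternative shown only for comparison, since a hard clamp gives exactly zero gradient for any value outside the bounds. The function names and the bound of 5 are illustrative assumptions.

```python
import torch


def hard_bounded_exp(raw, bound=5.0):
    """exp of the raw network output after a hard clamp to [-bound, bound]."""
    return raw.clamp(min=-bound, max=bound).exp()


def soft_bounded_exp(raw, bound=5.0):
    """exp of a tanh-squashed output; the exponent stays inside (-bound, bound)."""
    return (bound * torch.tanh(raw / bound)).exp()


# Stand-in for values coming out of the guide network (hypothetical numbers).
raw = torch.tensor([-6.0, -1.0, 0.0, 1.0, 6.0], requires_grad=True)

hard = hard_bounded_exp(raw)
soft = soft_bounded_exp(raw)
(hard_grad,) = torch.autograd.grad(hard.sum(), raw)
(soft_grad,) = torch.autograd.grad(soft.sum(), raw)

print("hard:", hard.detach(), "grad:", hard_grad)
print("soft:", soft.detach(), "grad:", soft_grad)
```

With the hard clamp, entries beyond the bound receive no gradient signal, so the network cannot learn its way back from there. The soft version shrinks the gradient smoothly instead, although it still becomes vanishingly small for very large inputs.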


Unfortunately, even when clipping the concentration (and other things that get exp-ed) returned from the neural net, I am still getting nans. I tried .clamp(min=-5, max=5) and .clamp(min=-4, max=4). The neural net returns NaNs in every element of the tensor. Right before this, some/many of the neural net’s values are NaNs.

I have not seen this issue with SGD (but I haven’t been able to train with SGD).

I tried ClippedAdam. Do you think I should experiment with optimizer params? For example: pyro.optim.ClippedAdam({"lr": 1e-6,'eps':1e-6})?

maybe. unfortunately i know essentially nothing about your model. usually in cases like this the numerical issues can be traced to one or two specific places. so i suggest you log as much information as you can during the course of training to see what’s happening just before you get nans. e.g. compute mins/maxes/medians of different parameters, not just log_conc.
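
One possible way to collect such statistics, as a rough sketch. The helper names summarize and summarize_module are hypothetical, and the small Linear layer only stands in for the real guide network.

```python
import torch


def summarize(name, tensor):
    """Return min/median/max over finite entries plus a count of NaN/inf entries."""
    flat = tensor.detach().float().flatten()
    finite = flat[torch.isfinite(flat)]
    stats = {
        "name": name,
        "numel": flat.numel(),
        "n_nonfinite": flat.numel() - finite.numel(),
    }
    if finite.numel() > 0:
        stats["min"] = finite.min().item()
        stats["median"] = finite.median().item()
        stats["max"] = finite.max().item()
    return stats


def summarize_module(module):
    """Summarize every parameter of an nn.Module, and its gradient if present."""
    rows = []
    for name, param in module.named_parameters():
        rows.append(summarize(name, param))
        if param.grad is not None:
            rows.append(summarize(name + ".grad", param.grad))
    return rows


# Stand-in network; in practice pass the guide network and any intermediate
# tensors (for example the clamped log concentration) every few hundred steps.
net = torch.nn.Linear(3, 2)
net(torch.ones(1, 3)).sum().backward()
for row in summarize_module(net):
    print(row)
```

Logging the gradient rows as well as the parameter rows helps to tell whether the weights drift to extreme values gradually or are destroyed by a single non-finite gradient.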

you can also try to do checkpointing as described e.g. here and then once your training fails reload at the last break point and do additional training/diagnosing with pdb or what have you

Just a guess, but I sometimes see NANs and loss diverging to -inf when some distribution’s scale parameter converges to zero because there is only one datapoint or because a downstream distribution collapses. As Martin mentioned, it would help us diagnose if you could share your model.

I looked into this more, and many signs point to it being a numerical stability issue with _log_prob_4 in pyro.distributions.projected_normal — Pyro documentation.

Here is a gist with some results / numbers: Google Colab

The para_part is (pos_piece + other_piece).log(), and sometimes other_piece is negative enough to make the sum negative, so the log returns a nan. Should it be -inf?

I’m happy to help / do work for a more numerically stable implementation.

The integral that was done in Mathematica is like the third moment of Normal, except that the integral is on the positive Real line.

From _log_prob_4 docs:

# This is the log of a definite integral, computed by mathematica:
# Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
# = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2

So no matter what t is, negative or positive, the integral inside para_part should never be negative, because x^3 and E^((x-t)^2/2) are never negative on the domain {x, 0, Infinity}. t just shifts the Gaussian, which stays non-negative.
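
A quick numerical check of this argument, as a sketch in plain float64 Python rather than torch. It compares the Mathematica closed form quoted above with a simple trapezoid-rule integration of the same integrand. The integration cutoff, the step count, and the test values of t are arbitrary choices.

```python
import math

SQRT_2PI = math.sqrt(2.0 * math.pi)


def integrand(x, t):
    """x^3 times a unit-variance Gaussian density centred at t."""
    return x ** 3 * math.exp(-0.5 * (x - t) ** 2) / SQRT_2PI


def closed_form(t):
    """The expression from the _log_prob_4 comment, before taking the log."""
    gauss_part = (2.0 + t * t) * math.exp(-0.5 * t * t) / SQRT_2PI
    erf_part = 0.5 * t * (3.0 + t * t) * (1.0 + math.erf(t / math.sqrt(2.0)))
    return gauss_part + erf_part


def trapezoid(t, n=20000):
    """Trapezoid rule on [0, max(t, 0) + 12]; the tail beyond that is negligible."""
    upper = max(t, 0.0) + 12.0
    h = upper / n
    total = 0.5 * (integrand(0.0, t) + integrand(upper, t))
    total += sum(integrand(i * h, t) for i in range(1, n))
    return h * total


for t in (-2.0, 0.0, 1.5):
    print(f"t={t:+.1f}  closed form={closed_form(t):.8f}  quadrature={trapezoid(t):.8f}")
```

For moderate t the two agree and are positive, which is consistent with the argument above: a negative value under the log comes from floating-point cancellation between the two terms of the closed form, not from the integral itself.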


What’s needed is an asymptotic expansion as t goes to a “large” negative number… like -4.1… (and I suspect that is not that large and is frequently encountered!). I suspect that the log of this integral (the para_part, up to the log of a positive constant), which is giving nans, can be shown to go to -inf as t goes more negative…

So what should be done?

  • truncate the log_prob of bad samples to some maximal large negative number
  • can the -inf be handled by Pyro / the Pytorch optimizer?

Some details on the model and guide.

I did clamp the log_concentration after it’s returned from the NN, but I have to clamp it to be quite spread out (~4.13) and not concentrated enough which prevents the posterior from being peaked enough. See “Figure 2” which shows samples projected on the sphere and how spread out they are.




perhaps the erf is the problem? i guess you could try decomposing the log_prob in terms of a “small t” and “large t” expression where the latter uses an asymptotic expansion for erf(t) like

[image: asymptotic expansion of erf(t)]
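
The image itself is not recoverable here. For reference, the standard large-argument expansion of the complementary error function, which may or may not be exactly what the image showed, is given below, together with what it implies for the closed form quoted earlier when t is large and negative. Here \phi denotes the standard normal density, a symbol introduced only for this note.

```latex
% Standard asymptotic series for erfc as x -> +infinity
\operatorname{erfc}(x) \sim \frac{e^{-x^{2}}}{x\sqrt{\pi}}
  \left(1 - \frac{1}{2x^{2}} + \frac{3}{4x^{4}} - \frac{15}{8x^{6}} + \cdots\right)

% For t < 0 the factor in the closed form is a small tail:
1 + \operatorname{erf}\!\left(\tfrac{t}{\sqrt{2}}\right)
  = \operatorname{erfc}\!\left(\tfrac{|t|}{\sqrt{2}}\right)

% Substituting the series, the terms of order t^2 and t^0 cancel, leaving
\int_{0}^{\infty} x^{3}\,\frac{e^{-(x-t)^{2}/2}}{\sqrt{2\pi}}\,dx
  \;\approx\; \frac{6\,\phi(t)}{t^{4}}\left(1 - \frac{10}{t^{2}} + \cdots\right),
  \qquad t \to -\infty,\quad \phi(t) = \frac{e^{-t^{2}/2}}{\sqrt{2\pi}}
```

The leading terms of the two pieces cancel exactly, which is why the directly evaluated closed form loses accuracy: the result is smaller than each piece by a factor of order 6/t^6.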

:+1: @geoffwoollard this is great, we’d love to have a more numerically stable implementation of ProjectedNormal.log_prob(). My initial implementation is both very naive (directly copied from Wolfram alpha) and not general (must be implemented for each dimension). I think this is a really nice distribution and I’m happy to help out with code review and testing of any numerical stability work you do.

@geoffwoollard Could you see if this PR resolves your issue? It simply replaces the (x + y).log() with torch.logaddexp(x.log(), y.log()) in a few places.

EDIT hmm I think we’ll need to be smarter…

Ok great. I’m looking into it @fritzo . @martinjankowiak yes, that seems a good idea about the erf.

The piece that seems to be unstable is the erf (goes small) times the cubic piece in t (goes big). The implementation of logsumexp in scipy has a scaling term b: scipy/_logsumexp.py at v1.8.0 · scipy/scipy · GitHub

So maybe a custom version of logaddexp.
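
As a rough idea of what such a custom version could look like, here is a two-term sketch in plain Python for the case where the second term enters with a negative sign. The function name log_diff_exp is hypothetical, and this is neither the scipy implementation nor the code in the PR.

```python
import math


def log_diff_exp(log_a, log_b):
    """Return log(exp(log_a) - exp(log_b)) while staying in log space.

    Gives -inf when the two terms are equal and nan when the difference
    would be negative, instead of raising or silently overflowing.
    """
    if log_b > log_a:
        return math.nan
    if log_b == log_a:
        return -math.inf
    # Factor out the larger term: a - b = a * (1 - b / a).
    return log_a + math.log1p(-math.exp(log_b - log_a))


# Small check against the direct computation: log(5 - 3) = log(2).
print(log_diff_exp(math.log(5.0), math.log(3.0)), math.log(2.0))

# Both exp(-1000) and exp(-1001) underflow to 0.0 in float64, so the direct
# formula would take the log of zero, but the log-space version stays finite.
print(log_diff_exp(-1000.0, -1001.0))
```

This keeps the magnitudes representable, but it does not by itself remove the cancellation when the two terms are nearly equal, so the inputs still need to be accurate to begin with.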

NB: the orange is flipped after the cusp at (0,-5) since I took the abs of the t3 term to be able to log it.

Great, I just updated this PR to replace erf with erfc and clamp the logs. That seems to handle numerical stability. @geoffwoollard would you mind testing that branch?
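
To illustrate the effect of that kind of change, here is a sketch in plain float64 Python, not the code from the PR. It evaluates the log of the closed form once with 1 + erf and once with erfc, and floors the value before taking the log. The floor value tiny and the function names are hypothetical choices.

```python
import math

SQRT_2 = math.sqrt(2.0)
LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_integral_erf(t):
    """Log of the closed form using 1 + erf, as in the original expression."""
    gauss = (2.0 + t * t) * math.exp(-0.5 * t * t - LOG_SQRT_2PI)
    tail = 1.0 + math.erf(t / SQRT_2)
    return math.log(gauss + 0.5 * t * (3.0 + t * t) * tail)


def log_integral_erfc(t, tiny=1e-300):
    """Same quantity using erfc(-t / sqrt(2)), with a floor before the log."""
    gauss = (2.0 + t * t) * math.exp(-0.5 * t * t - LOG_SQRT_2PI)
    tail = math.erfc(-t / SQRT_2)  # equals 1 + erf(t / sqrt(2)) without cancellation
    value = gauss + 0.5 * t * (3.0 + t * t) * tail
    return math.log(max(value, tiny))  # the floor keeps the log finite


t = -20.0
# Leading-order asymptotic value: log(6 * phi(t) / t^4).
asymptotic = -0.5 * t * t - LOG_SQRT_2PI + math.log(6.0) - 4.0 * math.log(abs(t))
print("erf version :", log_integral_erf(t))
print("erfc version:", log_integral_erfc(t))
print("asymptotic  :", asymptotic)
```

At t = -20 the erf-based tail rounds to exactly zero in float64, so that version silently drops the second term and overestimates the log by a wide margin, while the erfc version stays close to the asymptotic value. In float32 the same breakdown happens at much smaller |t|.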