Detecting NaN During Adam Optimization

Hello,

I have a model where I am optimizing parameters using the AutoDelta guide. Initially the training is fine, with the loss decreasing as expected. However, after a while the likelihood/loss becomes NaN. I have been inspecting parameters that might produce a NaN log prob, but it is time consuming to search every suspect parameter (large model). Is there a way to automatically detect which variable/parameter produces the NaN log prob? It could also occur in the log prior for a parameter, which would be harder to detect.

Rather than trying to chase down individual parameters: in my experience the loss can go to NaN when the variance of the loss estimate is too large. To reduce that variance you can do one or more of the following, all of which will increase the time it takes the model to fit:

  • Reduce Adam step size
  • Increase the number of particles
  • If you’re using minibatching, increase the batch size

I’m not aware of any tools in numpyro for chasing down specific problem parameters, but the devs would probably be more helpful with that than me. You can put in checks to make sure parameters stay in the right range. Without seeing your model it’s difficult to be specific, but IIRC jax has clip functions?

I figured out the specific issue in my case: the masking of observations in the likelihood was computing 0 * nan = nan, so NaNs leaked through the mask. The fix was to mask the values before passing them to the distribution. That solved my specific problem, but it doesn’t address the broader issue of likelihood values going to NaN; I suspect my parameterization/priors are not strong enough to prevent degenerate likelihoods, so this is more of a model specification issue.
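The 0 * nan trap above can be reproduced in a few lines; the log-likelihood here is a stand-in expression, not the poster's actual model:

```python
import jax.numpy as jnp

obs = jnp.array([1.0, jnp.nan, 2.0])   # missing observation encoded as NaN
mask = jnp.array([1.0, 0.0, 1.0])      # 0 marks the missing entry

# Broken: 0 * nan is still nan, so the NaN leaks through the mask.
log_lik = -0.5 * obs**2                # stand-in for a real log prob
bad = mask * log_lik                   # contains nan

# Fix: replace missing values *before* they reach the density,
# then zero out their (now finite) contribution with the mask.
clean_obs = jnp.where(mask == 1.0, obs, 0.0)
good = mask * (-0.5 * clean_obs**2)    # finite everywhere
print(bad, good)
```

The key point is that `jnp.where` must run before the density evaluation; masking afterwards cannot undo a NaN that has already been produced.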