NaN grad with MixtureOfDiagNormals

Hi, I have the following script (adapted from test code in the repo).

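import torch
from pyro.distributions import MixtureOfDiagNormals

# Anomaly mode makes autograd raise as soon as a backward function returns NaNs
torch.autograd.set_detect_anomaly(True)

# K = number of mixture components, D = dimension of each component.
# The values from the original post were not preserved; these are hypothetical
# placeholders (D=4 reportedly works, so the failure needs a different D).
K, D = 3, 2

locs = torch.rand(K, D).requires_grad_(True)
coord_scale = torch.ones(K, D) + 0.5 * torch.rand(K, D)
coord_scale = coord_scale.requires_grad_(True)
component_logits = (1.5 * torch.rand(K)).requires_grad_(True)

n_samples = 200000
sample_shape = torch.Size((n_samples,))
dist = MixtureOfDiagNormals(locs=locs, coord_scale=coord_scale, component_logits=component_logits)
z = dist.rsample(sample_shape=sample_shape)  # reparameterized, so gradients flow to the parameters
cost = torch.pow(z, 2.0).sum() / float(n_samples)
cost.backward()  # the NaN is detected during this backward pass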

Running the script raises the following error during the backward pass:
Function '_MixDiagNormalSampleBackward' returned nan values in its 0th output.
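That message format is what PyTorch's anomaly detection mode (torch.autograd.set_detect_anomaly or the torch.autograd.detect_anomaly context manager) raises when a backward function emits NaNs, so the script was presumably run with it enabled; that is an assumption, since the call is not shown in the post. A minimal torch-only sketch of the same mechanism, using a toy function in place of the mixture sampler (nan_backward_message is a hypothetical helper name):

```python
import torch

def nan_backward_message() -> str:
    """Trigger a NaN in a backward pass under anomaly mode and return the error text."""
    x = torch.tensor(0.0, requires_grad=True)
    try:
        with torch.autograd.detect_anomaly():
            # d/dx sqrt(x) = 1 / (2 sqrt(x)) is infinite at x = 0. Multiplying the
            # output by 0 makes the incoming gradient 0, and 0 / 0 gives NaN.
            y = torch.sqrt(x) * 0.0
            y.backward()
    except RuntimeError as err:
        return str(err)
    return "no error"

print(nan_backward_message())
```

Without anomaly mode the same code runs silently and x.grad simply ends up NaN, which is why enabling it is useful for finding the first backward function that misbehaves.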

If I set D=4 (as in the test code), it works.

Python version: 3.9.7
Pytorch version: 1.12.1+cu116
Pyro version: 1.8.3

Can you isolate where the NaN is? Is it in component_logits or elsewhere?

locs.grad, coord_scale.grad, and component_logits.grad all contain NaNs.
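One way to check each gradient after cost.backward() (with anomaly mode off, so the NaNs propagate instead of raising) is to test every .grad with torch.isnan. The helper below, nan_grad_report, is a hypothetical sketch; it is demonstrated on toy tensors a and b, but in the script above it would be called as nan_grad_report({"locs": locs, "coord_scale": coord_scale, "component_logits": component_logits}).

```python
import torch

def nan_grad_report(named_params):
    """Map each parameter name to True if its .grad contains any NaN."""
    report = {}
    for name, p in named_params.items():
        # p.grad stays None if the backward pass never reached this tensor
        report[name] = p.grad is not None and bool(torch.isnan(p.grad).any())
    return report

# Toy demo: sqrt'(0) is infinite, so 0 * inf makes a.grad[0] NaN; b's gradient is finite.
a = torch.tensor([0.0, 1.0], requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
loss = (torch.sqrt(a) * 0.0).sum() + b ** 2
loss.backward()
report = nan_grad_report({"a": a, "b": b})
print(report)  # {'a': True, 'b': False}
```

Because every parameter of the sampler feeds into the same backward function, a NaN produced there can show up in all of their gradients at once, as observed above.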

I resolved the issue by calling torch.set_default_dtype(torch.float64), which suggests the problem is numerical.
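torch.set_default_dtype only affects floating-point tensors created after the call, so it has to run before locs, coord_scale and component_logits are built. An alternative that avoids changing global state is to pass dtype=torch.float64 to each factory call. A minimal sketch of both points; make_params is a hypothetical helper that mirrors the parameter setup from the script:

```python
import torch

def make_params(K: int, D: int, dtype: torch.dtype = torch.float64):
    """Create the script's mixture parameters directly in the requested dtype."""
    locs = torch.rand(K, D, dtype=dtype).requires_grad_(True)
    coord_scale = torch.ones(K, D, dtype=dtype) + 0.5 * torch.rand(K, D, dtype=dtype)
    coord_scale = coord_scale.requires_grad_(True)
    component_logits = (1.5 * torch.rand(K, dtype=dtype)).requires_grad_(True)
    return locs, coord_scale, component_logits

# A tensor created before set_default_dtype keeps its original dtype.
before = torch.rand(3)
old_default = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
after = torch.rand(3)
torch.set_default_dtype(old_default)  # restore the global default
print(before.dtype, after.dtype)  # torch.float32 torch.float64 in a fresh session
```

Creating the parameters in float64 this way should also make the samples and the backward computation run in float64, since the distribution works in the dtype of the tensors it is given.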

Ah yes, sticking to double precision is probably a good idea for these distributions. The computation of the backward pass is quite complicated and numerically a bit delicate.
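As a generic illustration of why float32 can yield NaNs where float64 does not (this is not Pyro's actual backward computation), consider normalizing two tiny exponentials naively. In float32, exp(-200) underflows to 0, so the ratio becomes 0 / 0; in float64 the same expression is well defined. normalized_weight is a hypothetical helper:

```python
import torch

def normalized_weight(a: float, b: float, dtype: torch.dtype) -> float:
    """Compute exp(-a) / (exp(-a) + exp(-b)) naively in the given dtype."""
    ta = torch.tensor(a, dtype=dtype)
    tb = torch.tensor(b, dtype=dtype)
    num = torch.exp(-ta)
    return (num / (num + torch.exp(-tb))).item()

print(normalized_weight(200.0, 200.0, torch.float32))  # nan: both exponentials underflow to 0
print(normalized_weight(200.0, 200.0, torch.float64))  # 0.5

# A max-shifted formulation such as softmax stays finite even in float32.
stable = torch.softmax(torch.tensor([-200.0, -200.0]), dim=0)[0].item()
print(stable)  # 0.5
```

Switching to float64 widens the range before underflow or overflow sets in, which is why it often fixes this kind of failure without touching the algorithm itself.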
