Hi, I have the following script (adapted from the test code in the repo):
```python
import torch
from pyro.distributions import MixtureOfDiagNormals

torch.autograd.set_detect_anomaly(True)

K = 3   # number of mixture components
D = 50  # dimension; D = 4 (as in the test code) works
locs = torch.rand(K, D).requires_grad_(True)
coord_scale = torch.ones(K, D) + 0.5 * torch.rand(K, D)
coord_scale = coord_scale.requires_grad_(True)
component_logits = (1.5 * torch.rand(K)).requires_grad_(True)

n_samples = 200000
sample_shape = torch.Size((n_samples,))
dist = MixtureOfDiagNormals(locs=locs, coord_scale=coord_scale, component_logits=component_logits)
z = dist.rsample(sample_shape=sample_shape)
cost = torch.pow(z, 2.0).sum() / float(n_samples)
cost.backward()  # fails here with D = 50
```
Running it raises the following exception:
```
Function '_MixDiagNormalSampleBackward' returned nan values in its 0th output.
```
If I set `D=4` (as in the test code), it works.
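One possible explanation, offered only as a hypothesis and not checked against Pyro's source: if the custom backward (`_MixDiagNormalSampleBackward`) divides by per-sample mixture densities, each of which is a product over all D coordinates, those products can underflow to exactly zero in float32 once D is large, and 0/0 then yields NaN. The stdlib-only sketch below illustrates that general float32 effect, not Pyro's actual code. It emulates float32 by rounding through `array('f')` and uses a hypothetical two-component mixture (means 0 and 1 in every coordinate, unit scale, equal weights) evaluated at a point 3 units from the first mean in every coordinate; all helper names are made up for illustration.

```python
import math
from array import array

# Hypothetical helpers illustrating float32 underflow; NOT Pyro's implementation.
SQRT_2PI = math.sqrt(2.0 * math.pi)


def to_f32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return array("f", [x])[0]


def normal_pdf(x: float, loc: float, scale: float) -> float:
    return math.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * SQRT_2PI)


def density_f32(z, loc, scale):
    """Diagonal-normal density as a running float32 product over coordinates."""
    p = 1.0
    for zi in z:
        p = to_f32(p * to_f32(normal_pdf(zi, loc, scale)))
    return p


def log_density(z, loc, scale):
    """Same density, but accumulated as a sum of logs (float64)."""
    return sum(-0.5 * ((zi - loc) / scale) ** 2 - math.log(scale * SQRT_2PI) for zi in z)


def weight_naive(z, locs, scale=1.0):
    """Posterior weight of component 0 (equal mixture weights) from raw densities."""
    ps = [density_f32(z, loc, scale) for loc in locs]
    total = sum(ps)
    # IEEE-754 arithmetic (which torch follows) gives 0/0 = NaN;
    # plain Python would raise ZeroDivisionError, so mimic NaN explicitly.
    return ps[0] / total if total > 0.0 else math.nan


def weight_logspace(z, locs, scale=1.0):
    """Same weight via log-sum-exp, which never forms 0/0."""
    ls = [log_density(z, loc, scale) for loc in locs]
    m = max(ls)
    lse = m + math.log(sum(math.exp(l - m) for l in ls))
    return math.exp(ls[0] - lse)


for D in (4, 50):
    z = [3.0] * D      # 3 units from component 0 in every coordinate
    locs = [0.0, 1.0]  # hypothetical component means, shared across coordinates
    # For D = 50 the naive weight is nan; the log-space weight stays finite.
    print(D, weight_naive(z, locs), weight_logspace(z, locs))
```

With D = 4 both ways of computing the component-0 weight agree, while with D = 50 the raw-density version returns NaN and the log-sum-exp version stays finite. If underflow is indeed the cause, rerunning the reproducer with float64 tensors should move the failure to a larger D, which would be a cheap way to test the hypothesis.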
- Python version: 3.9.7
- PyTorch version: 1.12.1+cu116
- Pyro version: 1.8.3