I've encountered this error multiple times, but it is non-deterministic and only occurs occasionally.
```
Traceback (most recent call last):
  File "train.py", line 33, in <module>
    _, loss_dict = model.train(*data)
  File "/home/luoa/junting/video_prediction/models/air_base_model.py", line 389, in train
    loss = svi.loss_and_grads(svi.model, svi.guide, input, output)
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/infer/elbo.py", line 65, in loss_and_grads
    return self.which_elbo.loss_and_grads(model, guide, *args, **kwargs)
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py", line 133, in loss_and_grads
    for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py", line 87, in _get_traces
    log_r = model_trace.log_pdf() - guide_trace.log_pdf()
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/poutine/trace.py", line 71, in log_pdf
    site["value"], *args, **kwargs) * site["scale"]
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/distributions/random_primitive.py", line 42, in log_pdf
    return self.dist_class(*args, **kwargs).log_pdf(x)
  File "/home/luoa/anaconda3/lib/python3.6/site-packages/pyro/distributions/distribution.py", line 185, in log_pdf
    return torch.sum(self.batch_log_pdf(x, *args, **kwargs))
RuntimeError: value cannot be converted to type double without overflow: -inf
```
I don't know how to debug this, because when I re-run exactly the same thing, the error sometimes doesn't occur. I'm using the Adam optimizer with a learning rate of 2e-4 (or 1e-4).
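One way to make an intermittent failure like this replayable is to choose a seed explicitly at the start of every run and log it, so that a crashing run can be rerun with the same random draws. The sketch below assumes only Python's and PyTorch's RNGs matter (add `numpy.random.seed` if NumPy is involved); `start_run` is a hypothetical helper name. Recent Pyro versions also provide `pyro.set_rng_seed`, which does this seeding for you. Some GPU kernels are nondeterministic even with a fixed seed, so an exact replay is most reliable on CPU.

```python
import random

import torch


def start_run(seed=None):
    """Hypothetical helper: seed the RNGs for one run and return the seed used."""
    if seed is None:
        # Draw a fresh seed so runs still differ, but remember which one was used.
        seed = random.randrange(2**31)
    random.seed(seed)        # Python's built-in RNG
    torch.manual_seed(seed)  # PyTorch's RNG, which Pyro's sampling draws from
    return seed


seed = start_run()
print(f"run seed = {seed}")  # log this; call start_run(seed) later to replay a crash
```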
The only possible cause I can think of is that I use softplus to get the standard deviation for sampling. My code looks like this:
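```python
import torch.nn.functional as F
import pyro
import pyro.distributions as dist

x = self.model(...)
x_mu = x[:, :N]                 # first N columns: means
x_sigma = F.softplus(x[:, N:])  # remaining columns: standard deviations, made positive by softplus
x = pyro.sample('name', dist.normal, x_mu, x_sigma)
```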
Softplus is the standard way of doing this, though: the VAE and AIR tutorials both use softplus to get the standard deviation. So this might not be the problem.
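Even though softplus is standard, it is worth ruling out one float32 failure mode. Softplus is mathematically positive, but for a strongly negative network output it returns a standard deviation that is positive yet tiny, and then the `(x - mu)**2 / (2 * sigma**2)` term of the Normal log-density overflows, giving a log-density of `-inf`. Whether this is what happens in the traceback above is only a hypothesis. The sketch assumes a recent PyTorch and uses `torch.distributions.Normal` rather than the older `dist.normal` from the question. The `min_sigma` floor and its value `1e-3` are hypothetical illustrations, not a documented recommendation.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def softplus_normal_log_prob(raw_sigma, x, mu=0.0, min_sigma=0.0):
    """Log-density of x under Normal(mu, softplus(raw_sigma) + min_sigma)."""
    sigma = F.softplus(raw_sigma) + min_sigma  # min_sigma: hypothetical floor
    return sigma, Normal(torch.full_like(sigma, mu), sigma).log_prob(x)


raw = torch.tensor([-60.0])  # a strongly negative network output (float32)
x = torch.tensor([1.0])      # a value one unit away from the mean

sigma, lp = softplus_normal_log_prob(raw, x)
print(sigma)  # positive but tiny: softplus(-60) is roughly exp(-60)
print(lp)     # tensor([-inf]): the squared-residual term overflows float32

sigma, lp = softplus_normal_log_prob(raw, x, min_sigma=1e-3)
print(lp)     # finite once sigma is bounded away from zero
```

Adding a small constant after the softplus (or clamping its output) keeps the scale bounded away from zero; the right floor depends on the scale of the data.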
Any suggestions on how to debug this?
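A cheap diagnostic is to check the distribution parameters right before sampling, so that a run fails at the first bad value with some context instead of deep inside the ELBO computation. The sketch below is one way to do that; `check_normal_params` and the `min_sigma=1e-6` threshold are hypothetical and would need tuning, and `torch.isfinite` assumes a reasonably recent PyTorch.

```python
import torch


def check_normal_params(mu, sigma, min_sigma=1e-6):
    """Hypothetical guard: raise ValueError describing any unsafe Normal parameters."""
    problems = []
    if not torch.isfinite(mu).all():
        problems.append("mu has inf/NaN entries")
    if not torch.isfinite(sigma).all():
        problems.append("sigma has inf/NaN entries")
    elif (sigma < min_sigma).any():
        # Tiny (or zero) scales are where the log-density can blow up to -inf.
        problems.append(f"min sigma = {sigma.min().item():.3e} < {min_sigma:g}")
    if problems:
        raise ValueError("; ".join(problems))
```

In the code from the question, calling `check_normal_params(x_mu, x_sigma)` between computing `x_sigma` and the `pyro.sample` call would report which parameter went bad and when.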