I also often see NaNs when using a parametrization like Poisson(my_param.exp()). Two numerical tricks @martinjankowiak and I use are either
- replace torch.exp() with torch.nn.functional.softplus(), or
- replace torch.exp() with an appropriately scaled torch.sigmoid().

I often use the following:
import math

def bounded_exp(x, bound):
    return (x - math.log(bound)).sigmoid() * bound  # ~exp(x) while exp(x) << bound, saturates at bound
and you can pick a reasonable upper bound
by looking at your data.
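
To make the difference concrete, here is a minimal sketch comparing the three options. The parameter values and the bound of 1e4 are made up for illustration and do not come from any real model; `bounded_exp` is the same function as above, repeated so the snippet runs on its own.

```python
import math

import torch
import torch.nn.functional as F


def bounded_exp(x, bound):
    # Scaled sigmoid: close to exp(x) while exp(x) << bound, saturating at bound.
    return (x - math.log(bound)).sigmoid() * bound


# Hypothetical unconstrained parameter values, chosen only for illustration.
my_param = torch.tensor([-2.0, 0.0, 2.0, 100.0])
bound = 1e4  # assumed upper bound; in practice pick it by looking at your data

rate_exp = my_param.exp()                    # overflows to inf at x = 100 in float32
rate_bounded = bounded_exp(my_param, bound)  # stays finite, never exceeds bound
rate_softplus = F.softplus(my_param)         # stays finite, grows linearly for large x

print(rate_exp)
print(rate_bounded)
print(rate_softplus)
```

For small inputs the bounded version tracks exp() closely, so it behaves like the usual log-link parametrization until the rate approaches the bound. Softplus also avoids the overflow, but it no longer matches exp() for moderate positive inputs, so the parameter takes on a different meaning.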