Truncated Poisson SVI takes too long

I also often see NaNs when using a parametrization like Poisson(my_param.exp()), since exp() can overflow for large inputs. Two numerical tricks @martinjankowiak and I use are to

  1. replace torch.exp() with torch.nn.functional.softplus(), or
  2. replace torch.exp() with an appropriately scaled torch.sigmoid(). I often use the following:
import math
def bounded_exp(x, bound):
    # Close to exp(x) while exp(x) << bound; saturates at `bound` for large x.
    return (x - math.log(bound)).sigmoid() * bound

and you can pick a reasonable upper bound by looking at your data (for example, something comfortably above the largest count you observe).
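
To make the difference concrete, here is a minimal sketch comparing the three link functions on a few unconstrained values. The inputs and the `bound = 1e4` cap are made up for illustration, not taken from any real model; in practice the resulting rate is what you would pass to Poisson(...).

```python
import math

import torch
import torch.nn.functional as F


def bounded_exp(x, bound):
    # Same helper as above: close to exp(x) while exp(x) << bound, capped at bound.
    return (x - math.log(bound)).sigmoid() * bound


# Hypothetical unconstrained parameter values; the last one is deliberately extreme.
x = torch.tensor([-2.0, 0.0, 5.0, 100.0])
bound = 1e4  # assumed cap, e.g. chosen from the largest plausible count

rate_exp = x.exp()                    # overflows to inf at x = 100 in float32
rate_softplus = F.softplus(x)         # grows only linearly for large x
rate_bounded = bounded_exp(x, bound)  # never exceeds `bound`

print(rate_exp)
print(rate_softplus)
print(rate_bounded)
```

Note the trade-off: softplus changes the shape of the link everywhere (it is only roughly linear for large x and roughly exp(x) for very negative x), whereas bounded_exp stays close to exp(x) until the rate approaches the bound.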