Truncated Poisson SVI takes too long

I’m trying to implement right-truncated Poisson regression. The probability I’m trying to implement is on page 28.
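For reference, this is the standard pmf of a Poisson(λ) variable right-truncated at c. I'm assuming it is the same quantity as the page-28 formula, which isn't reproduced here. The e^{-λ} factor cancels between numerator and denominator, which gives the log-space form on the second line:

$$
P(Y = y \mid \lambda, c) = \frac{\lambda^{y} / y!}{\sum_{k=0}^{c} \lambda^{k} / k!}, \qquad y = 0, 1, \dots, c
$$

$$
\log P(Y = y \mid \lambda, c) = y \log \lambda - \log y! - \log \sum_{k=0}^{c} \exp\left(k \log \lambda - \log k!\right)
$$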

I’m implementing it by modifying the Categorical distribution, and it seems to work. But SVI and MCMC take so long that they are almost impossible to run.

This is my custom distribution:

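```python
import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin
from torch.distributions.utils import broadcast_all


class rtPoisson(torch.distributions.Categorical, TorchDistributionMixin):
    """Poisson(rate) right-truncated to the support {0, 1, ..., trunc}."""

    def __init__(self, rate, trunc, validate_args=None):
        # broadcast_all returns a tuple; unpack its single element
        (self.rate,) = broadcast_all(rate)
        self.trunc = trunc
        self.pois = dist.Poisson(self.rate)
        # support values 0..trunc, shaped (trunc + 1, 1, ..., 1) to broadcast against rate
        support = torch.arange(int(trunc) + 1, dtype=self.rate.dtype, device=self.rate.device)
        support = support.reshape((-1,) + (1,) * self.rate.dim())
        temp_cal = self.pois.log_prob(support).exp()  # shape (trunc + 1,) + rate.shape
        # normalizing constant P(Y <= trunc), same shape as rate
        self.const = temp_cal.sum(0)
        # move the support dimension last: shape rate.shape + (trunc + 1,)
        probs = temp_cal.movedim(0, -1) / self.const.unsqueeze(-1)
        super().__init__(probs=probs, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # rebuild from an expanded rate so that rate, pois and const stay consistent
        validate_args = self.__dict__.get('_validate_args')
        rate = self.rate.expand(batch_shape)
        return type(self)(rate, self.trunc, validate_args=validate_args)

    def log_prob(self, value):
        # const has the same shape as rate, so this stays elementwise
        return self.pois.log_prob(value) - self.const.log()

    def enumerate_support(self, expand=True):
        # super() is torch's Categorical, which enumerates 0..trunc
        result = super().enumerate_support(expand=expand)
        if not expand:
            result._pyro_categorical_support = id(self)
        return result
```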
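One shape detail worth checking in a log_prob like this, since on its own it could make SVI very slow: if the normalizer is stored as a column of shape (N, 1) while the Poisson log-prob of the observations has shape (N,), subtracting them broadcasts to an (N, N) matrix instead of N values. With N around 10,000 that is 10^8 entries per evaluation. A tiny demonstration with made-up numbers (n = 5 stands in for N):

```python
import torch

n = 5  # stands in for ~10,000 observations
log_pois = torch.randn(n)           # like pois.log_prob(value): shape (n,)
const_col = torch.rand(n, 1) + 1.0  # normalizer kept as a column: shape (n, 1)

bad = log_pois - const_col.log()               # broadcasts to shape (n, n)
good = log_pois - const_col.squeeze(-1).log()  # elementwise: shape (n,)

print(bad.shape, good.shape)  # torch.Size([5, 5]) torch.Size([5])
```

The intended values sit on the diagonal of `bad`; everything off the diagonal is wasted work that still gets summed into the loss.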

And this is the model and the code that runs SVI:

```python
import pyro
import pyro.distributions as dist
import torch


class rtPois:
    def __init__(self, trunc):
        self.trunc = trunc

    def model(self, data, ratings):
        # one Gamma(1, 1) coefficient per feature (identity link)
        with pyro.plate("betas", data.shape[1]):
            betas = pyro.sample("beta", dist.Gamma(1., 1.))
        lambda_ = torch.sum(betas * data, axis=1)
        with pyro.plate("ratings", data.shape[0]):
            y = pyro.sample("obs", rtPoisson(lambda_, self.trunc), obs=ratings)
        return y
```

I used an identity link with a Gamma prior. I also tried a Normal prior with an exponential link, which gave me inf.

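```python
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

# load() is my own pickle-loading helper (not shown)
data = load('data_pickle/data')
ratings = load('data_pickle/ratings')
rest_data = load('data_pickle/rest_data')

rtpois = rtPois(trunc=8)  # support is {0, ..., 8}
guide = AutoDiagonalNormal(rtpois.model)
rtpois_ratings = rtpois.process_ratings(ratings)  # process_ratings is not shown
svi_model = SVI(rtpois.model,
                guide,
                optim.Adam({"lr": .005}),
                loss=Trace_ELBO())  # loss choice assumed: the original line was cut off

loss_list = []
for i in range(10000):
    ELBO = svi_model.step(data, rtpois_ratings)
    loss_list.append(ELBO)
    if i % 500 == 0:
        print(i, ELBO)
```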

That is the SVI code. My data has around 200 features and around 10,000 observations.

Could you let me know why it is so slow (or not working)?

What value of trunc are you using?

I would replace `lambda_ = torch.sum(betas * data, axis=1)` with a `torch.matmul`, something like:

```python
lambda_ = data.matmul(betas.unsqueeze(-1)).squeeze(-1)
```
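A quick sanity check on hypothetical random data (1,000 rows, 200 features) that the matmul form gives the same values as the `torch.sum(betas * data, axis=1)` form, without materializing the full elementwise product first:

```python
import torch

torch.manual_seed(0)
data = torch.rand(1000, 200)  # hypothetical: 1000 rows, 200 features
betas = torch.rand(200)

# elementwise product builds a (1000, 200) temporary, then sums each row
via_sum = torch.sum(betas * data, dim=1)
# a single matrix-vector product: (1000, 200) @ (200, 1) -> (1000, 1) -> (1000,)
via_matmul = data.matmul(betas.unsqueeze(-1)).squeeze(-1)

print(via_matmul.shape)  # torch.Size([1000])
```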

Also, I don’t see any enumeration in this model, so you could use JitTrace_ELBO, which will be a little faster.

I'm using 9 for this.

Also, I think it would be faster if I used a GPU.

But now I have another problem: the loss keeps becoming NaN or Inf. I think it is because of the exponential link. This is my code with a Normal prior and an exponential (log-normal) link:

```python
def model(self, data, ratings, trunc):
    with pyro.plate("betas", data.shape[1]):
        betas = pyro.sample("beta", dist.Normal(0, 1))
    # print(torch.sum(betas * data, axis=1).max())

    lambda_ = torch.matmul(data, betas).exp()
    with pyro.plate("ratings", data.shape[0]):
        y = pyro.sample("obs", rtPoisson(lambda_, trunc), obs=ratings)
    return y
```

The maximum of lambda_ is around exp(20). exp(20) by itself is fine, but computing the categorical probabilities requires lambda^10, which is Inf.
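Normalizing in log space sidesteps this. torch's `Poisson.log_prob` is already computed in log space, so with λ around exp(20) every `exp(log_prob)` underflows to 0 and exponentiate-then-normalize gives 0/0 = NaN. Subtracting a `logsumexp` instead never forms λ^k. A minimal sketch in plain torch; `trunc_poisson_log_pmf` is a hypothetical helper, not part of Pyro:

```python
import math

import torch


def trunc_poisson_log_pmf(rate, trunc):
    """Log-pmf of Poisson(rate) right-truncated to {0, ..., trunc}.

    rate has shape (N,); the result has shape (N, trunc + 1).
    """
    k = torch.arange(trunc + 1, dtype=rate.dtype)
    r = rate.unsqueeze(-1)
    # k * log(rate) - rate - log(k!): never forms rate ** k directly
    log_p = torch.xlogy(k, r) - r - torch.lgamma(k + 1)
    # normalize in log space; stays finite even if every exp(log_p) underflows
    return log_p - torch.logsumexp(log_p, dim=-1, keepdim=True)


rate = torch.tensor([0.5, 3.0, math.exp(20)])  # last entry like the largest lambda_ here

# naive approach: exponentiate Poisson log-probs first, then normalize
naive = torch.distributions.Poisson(rate.unsqueeze(-1)).log_prob(torch.arange(9.0)).exp()
naive_probs = naive / naive.sum(-1, keepdim=True)  # 0 / 0 = NaN in the last row

stable_probs = trunc_poisson_log_pmf(rate, 8).exp()
print(naive_probs[2])   # all NaN
print(stable_probs[2])  # finite, sums to 1
```

Inside rtPoisson, passing the normalized log-pmf as `logits=` to Categorical, and subtracting the same `logsumexp` term in `log_prob`, would keep the distribution finite for large rates.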

Do you know of any ways to workaround this?

I also often see NaNs when using a parametrization like Poisson(my_param.exp()). Two numerical tricks @martinjankowiak and I use are either

  1. replace torch.exp() with torch.nn.functional.softplus(), or
  2. replace torch.exp() with an appropriately scaled torch.sigmoid(). I often use the following:

```python
import math

def bounded_exp(x, bound):
    return (x - math.log(bound)).sigmoid() * bound
```

and you can pick a reasonable upper bound by looking at your data.
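A small comparison of the three links on a few inputs. The bound of 100.0 is a hypothetical value standing in for one picked from the data. bounded_exp tracks exp(x) while exp(x) is much smaller than the bound and never exceeds the bound; softplus grows only linearly for large x:

```python
import math

import torch
import torch.nn.functional as F


def bounded_exp(x, bound):
    # equals exp(x) / (1 + exp(x) / bound): close to exp(x) when exp(x) << bound, capped at bound
    return (x - math.log(bound)).sigmoid() * bound


x = torch.tensor([-5.0, 0.0, 5.0, 20.0, 100.0])
print(torch.exp(x))           # the last entry overflows to inf in float32
print(F.softplus(x))          # roughly x for large x
print(bounded_exp(x, 100.0))  # never above 100
```

In the exp-link model above, this would mean `lambda_ = bounded_exp(torch.matmul(data, betas), bound)` instead of `torch.matmul(data, betas).exp()`.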