I’m trying to implement right-truncated Poisson regression.
http://www.ce.memphis.edu/7906/2014Fall/Lecture-19_v1.pdf
Page 28 gives the probability mass function that I’m trying to implement.
I’m implementing it by modifying parts of the Categorical distribution.
It appears to work, but SVI and MCMC take so long that it is almost impossible to run.
This is my custom distribution:
class rtPoisson(torch.distributions.Categorical, TorchDistributionMixin):
    """Right-truncated Poisson distribution on the integers ``0, 1, ..., trunc``.

    The probability mass is the Poisson pmf renormalised over the truncated
    support::

        P(Y = y) = Poisson(y | rate) / sum_{k=0}^{trunc} Poisson(k | rate)

    The distribution is backed by a ``Categorical`` over the ``trunc + 1``
    support points, so sampling and enumeration come from the base class,
    while ``log_prob`` is evaluated in closed form.

    Args:
        rate: Poisson rate(s) before truncation; any batch shape.
        trunc: Largest value in the support (inclusive). Converted with
            ``int(trunc)``, so a Python number or a 0-d tensor both work.
        validate_args: Forwarded to ``Categorical``.
    """

    def __init__(self, rate, trunc, validate_args=None):
        # broadcast_all returns a sequence; with one argument it has one entry.
        self.rate = broadcast_all(rate)[0]
        self.trunc = trunc
        self.pois = dist.Poisson(self.rate)

        # The support size follows ``trunc`` instead of a hard-coded 9.
        self._max_value = int(trunc)
        support = torch.arange(
            self._max_value + 1,
            dtype=self.rate.dtype,
            device=self.rate.device,
        )

        # Untruncated log-pmf at every support point: shape rate.shape + (K,).
        log_pmf = dist.Poisson(self.rate.unsqueeze(-1)).log_prob(support)

        # Normalise in log space: exp().sum() underflows to 0 for large rates,
        # which turns the normaliser into -inf / NaN.
        self._log_norm = torch.logsumexp(log_pmf, dim=-1)  # shape rate.shape
        # Normalising constant with a trailing singleton dimension, kept under
        # the attribute name used by earlier versions of this class.
        self.const = self._log_norm.exp().unsqueeze(-1)

        super().__init__(
            logits=log_pmf - self._log_norm.unsqueeze(-1),
            validate_args=validate_args,
        )

    def expand(self, batch_shape):
        """Return a copy of this distribution expanded to ``batch_shape``."""
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            # The base class declined to expand this subclass; rebuild the
            # distribution from an expanded rate instead.
            validate_args = self.__dict__.get('_validate_args')
            rate = self.rate.expand(batch_shape)
            return type(self)(rate, self.trunc, validate_args=validate_args)

    def log_prob(self, value):
        """Return the truncated log-probability of ``value``.

        The result has the broadcast shape of ``value`` and ``rate``. Values
        above ``trunc`` lie outside the support and get ``-inf``.
        """
        # Enumerated support values are integer-typed; evaluate in rate's dtype.
        value = torch.as_tensor(value).to(self.rate.dtype)
        # ``_log_norm`` has the same shape as ``rate``. Subtracting a
        # (N, 1)-shaped normaliser here would broadcast the result to (N, N).
        log_pro = self.pois.log_prob(value) - self._log_norm
        return log_pro.masked_fill(value > self._max_value, float('-inf'))

    def enumerate_support(self, expand=True):
        """Enumerate the support points ``0..trunc`` via the base class."""
        result = super().enumerate_support(expand=expand)
        if not expand:
            # Tag the unexpanded support with the id of its distribution.
            result._pyro_categorical_support = id(self)
        return result
And this is the model and the code that runs SVI:
class rtPois:
    """Regression model with a right-truncated Poisson likelihood."""

    def __init__(self, trunc):
        """Store ``trunc``, which ``model`` passes on to ``rtPoisson``."""
        self.trunc = trunc

    def model(self, data, ratings):
        """Pyro model: Gamma-distributed coefficients with an identity link.

        Args:
            data: Design matrix; one coefficient is sampled per column
                (``data.shape[1]``) and one observation per row
                (``data.shape[0]``).
            ratings: Observed responses passed as ``obs``.

        Returns:
            The value of the ``"obs"`` sample site.
        """
        n_obs, n_features = data.shape[0], data.shape[1]

        # One coefficient per column of ``data``.
        with pyro.plate("betas", n_features):
            betas = pyro.sample("beta", dist.Gamma(1, 1))

        # Identity link: the rate is the linear predictor itself.
        lambda_ = (betas * data).sum(dim=1)

        with pyro.plate("ratings", n_obs):
            return pyro.sample(
                "obs",
                rtPoisson(lambda_, self.trunc),
                obs=ratings,
            )
I tried an identity link with a Gamma prior. I also tried a Normal prior with an exp link, which gave me inf.
from pyro.infer.autoguide import AutoDiagonalNormal

# Inclusive upper bound of the truncated support. It belongs to the model
# object: rtPois.__init__ requires it and model() takes only (data, ratings).
trunc = torch.tensor(8.)

data = load('data_pickle/data')
ratings = load('data_pickle/ratings')
rest_data = load('data_pickle/rest_data')

rtpois = rtPois(trunc)
guide = AutoDiagonalNormal(rtpois.model)

#Poisson
# NOTE(review): process_ratings is not defined on the rtPois class shown in
# this post -- confirm it exists before running.
rtpois_ratings = rtpois.process_ratings(ratings)

# NOTE(review): the model shown has no discrete latent sites, so an
# enumerating ELBO is presumably unnecessary -- confirm and consider a plain
# Trace_ELBO. num_samples looks unrelated to step(); verify against the Pyro docs.
svi_model = SVI(rtpois.model,
                guide,
                optim.Adam({"lr": .005}),
                loss=JitTraceEnum_ELBO(),
                num_samples=500)

pyro.clear_param_store()
loss_list = []
for i in range(10000):
    # step() forwards its arguments to model(data, ratings); a third
    # positional argument would raise a TypeError.
    loss = svi_model.step(data, rtpois_ratings)
    if i % 500 == 0:
        print(loss)
    loss_list.append(loss)
This is the SVI code.
My data consists of around 200 features, with around 10,000 observations.
Could you let me know why it is so slow (or not working)?