Hey,
I am trying to implement the model from this paper in pyro. I have a working pymc3 model (code at the bottom). My pyro code is below. Without the jit compiler the code seems to work but is very slow, 2 samples a second on my machine. With the jit compiler turned on, there are a bunch of warnings and the model does not converge. Any help would be most appreciated. Thanks.
Warnings
0/2000 [00:00, ?it/s]/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:28: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:33: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:33: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:34: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:34: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
Pryo code:
import torch
import pyro
import pyro.distributions as dist
import pyro.infer.mcmc as mcmc
from pyro.infer import MCMC, NUTS
dat = torch.tensor([[100. , 63.1, 46.8, 38.2, 32.6, 28.9, 26.2, 24.1],
[-1, 36.9, 16.300000000000004, 8.599999999999994, 5.600000000000001,
3.700000000000003, 2.6999999999999993, 2.099999999999998]])
class sBG(pyro.distributions.Distribution):
def __init__(self, alpha, beta):
self.alpha = alpha
self.beta = beta
def sample(self):
return torch.tensor(1.)
def log_prob(self, data):
active = data[0,:]
lost = data[1,:]
n = active.shape[0]
p = [0., self.alpha / (self.alpha + self.beta)]
s = [0., 1 - p[1]]
for t in range(2, n):
pt = ((self.beta + t - 2) / (self.alpha + self.beta + t - 1)) * p[t-1]
p.append(pt)
s.append(s[t-1] - p[t])
p = torch.tensor(p)
s = torch.tensor(s)
died = torch.mul(torch.log(p[1:]), lost[1:])
still_active = torch.log(s[-1]) * active[-1]
logp = torch.sum(died) + still_active
return logp
def model(data):
alpha = pyro.sample('alpha', dist.Uniform(0, 10))
beta = pyro.sample('beta' , dist.Uniform(0, 10))
return pyro.sample('p', sBG(alpha, beta), obs=data)
nuts_kernel = NUTS(model, jit_compile=False)
mcmc= MCMC(nuts_kernel, num_samples=100, warmup_steps=10)
mcmc.run(dat)
Pymc3 code (works):
def sBG_model(data):
num_treatments = data['active'].shape[0]
n = len(data['active'])
with pm.Model() as model:
# uniformative priors
alpha = pm.Uniform('alpha', 0.00001, 10.0, testval=1)
beta = pm.Uniform('beta', 0.00001, 10.0, testval=1)
## defined in equation 7 in the paper
p = [0., alpha / (alpha + beta)]
s = [0., 1 - p[1]]
for t in range(2, n):
pt = ((beta + t - 2) / (alpha + beta + t - 1)) * p[t-1]
p.append(pt)
s.append(s[t-1] - p[t])
# theano type conversion
p = tt.stack(p)
s = tt.stack(s)
def logp(active, lost):
# Those who've churned along the way...
died = tt.mul(tt.log(p[1:]), lost[1:])
# and those still active in last period
still_active = tt.log(s[-1]) * active[-1]
return tt.sum(died) + still_active
retention = pm.DensityDist('retention', logp, observed=data)
return model