Hi,
I just found this tutorial on how Pyro can be used for predicting time-to-event. I want to use it to predict time-to-event on a dataset that contains recurrent events, each with some associated features (e.g., purchases a user made).
Here’s a reproducible toy example:
# Single covariate, one value per event.
x11 = [
    4.6151, 0.0000, 4.6151,
    0.0000, 0.0000, 0.0000,
    0.0000, 4.6151, 4.6151,
]
# Censoring indicator per event: 1 if the event is not censored, 0 otherwise.
y11 = [1.0] * 9
# Observed time to event, aligned index-by-index with x11 and y11.
y22 = [
    1.0, 0.0, 4.0,
    3.0, 2.0, 1.0,
    0.0, 0.0, 0.0,
]
def model():
    """Exponential time-to-event regression with censored observations.

    Reads the module-level lists ``x11`` (feature), ``y11`` (1 if the event
    is not censored, 0 otherwise) and ``y22`` (time to event).

    Sample sites:
        ``a_model``, ``b_model``: slope and intercept, each Normal(0, 10).
        ``obs_{i}``: observed time for every uncensored event ``i``.
        ``truncation_label_{i}``: Bernoulli site for every censored event
            ``i``, scored with the probability that the event time exceeds
            the recorded time.
    """
    a_model = pyro.sample("a_model", dist.Normal(0, 10))
    b_model = pyro.sample("b_model", dist.Normal(0, 10))

    # Build the data tensors in torch's default floating dtype.
    # torch.from_numpy(np.array(...)) on Python floats produces float64
    # tensors, which do not match the dtype of the sampled parameters above.
    # NOTE(review): that mismatch is the likely cause of the reported crash
    # on older PyTorch versions -- confirm against the actual error message.
    dtype = torch.get_default_dtype()
    x1 = torch.tensor(x11, dtype=dtype)
    truncation_label = torch.tensor(y11, dtype=dtype)
    y2 = torch.tensor(y22, dtype=dtype)

    # softplus keeps the per-event mean time strictly positive.
    link = torch.nn.functional.softplus(a_model * x1 + b_model)

    for i, (mean_time, label, duration) in enumerate(
            zip(link, truncation_label, y2)):
        # Exponential is parameterised by its rate, i.e. 1 / mean.
        y_hidden_dist = dist.Exponential(1 / mean_time)
        # Is this event labeled as not censored?
        if label == 1:
            pyro.sample("obs_{}".format(i),
                        y_hidden_dist,
                        obs=duration)
        else:
            # Probability that the event time exceeds the recorded time.
            truncation_prob = 1 - y_hidden_dist.cdf(duration)
            # label is 0 in this branch, so observing it directly would
            # score 1 - truncation_prob = P(T <= duration). Observe the
            # complementary "is censored" indicator (1 - label == 1) so the
            # site contributes P(T > duration) to the likelihood.
            pyro.sample("truncation_label_{}".format(i),
                        dist.Bernoulli(truncation_prob),
                        obs=1 - label)
# Start from an empty parameter store.
pyro.clear_param_store()

# HMC kernel over `model`, then a very short chain (toy settings).
hmc_kernel = HMC(model, step_size=0.1, num_steps=4)
mcmc_run = MCMC(hmc_kernel, num_samples=5, warmup_steps=1).run()

# Marginal over the "a_model" site; draw 50 samples from it and plot them.
marginal_a = EmpiricalMarginal(mcmc_run, sites="a_model")
posterior_a = [marginal_a.sample() for _ in range(50)]
sns.distplot(posterior_a)
But this crashes with the error in the topic title. What am I doing wrong?