Modeling time to event: error "expected type torch.FloatTensor but got torch.DoubleTensor"

Hi,

I just found this tutorial on how Pyro can be used to predict time to event. I want to apply it to a dataset of recurrent events, each with some associated features (e.g., purchases a user made).

Here’s a reproducible toy example:

```python
# some feature
x11 = [4.6151, 0.0000, 4.6151, 0.0000, 0.0000, 0.0000, 0.0000, 4.6151, 4.6151]
# 1 if the event is not censored, 0 otherwise
y11 = [1., 1., 1., 1., 1., 1., 1., 1., 1.]
# actual time to event
y22 = [1., 0., 4., 3., 2., 1., 0., 0., 0.]

def model():
    a_model = pyro.sample("a_model", dist.Normal(0, 10))  ## Note [2]
    b_model = pyro.sample("b_model", dist.Normal(0, 10))

    x1 = torch.from_numpy(np.array(x11))
    link = torch.nn.functional.softplus(a_model * x1 + b_model)  ## Note [3]

    truncation_label = torch.from_numpy(np.array(y11))
    y2 = torch.from_numpy(np.array(y22))

    for i in range(len(x1)):
        y_hidden_dist = dist.Exponential(1 / link[i])  ## Note [4]

        # is this event observed (i.e. not censored)?
        if truncation_label[i] == 1:
            y_real = pyro.sample("obs_{}".format(i),
                                 y_hidden_dist,
                                 obs=y2[i])
        else:
            truncation_prob = 1 - y_hidden_dist.cdf(y2[i])
            pyro.sample("truncation_label_{}".format(i),
                        dist.Bernoulli(truncation_prob),
                        obs=truncation_label[i])


pyro.clear_param_store()
hmc_kernel = HMC(model,
                 step_size=0.1,
                 num_steps=4)
mcmc_run = MCMC(hmc_kernel,
                num_samples=5,
                warmup_steps=1).run()
marginal_a = EmpiricalMarginal(mcmc_run,
                               sites="a_model")
posterior_a = [marginal_a.sample() for i in range(50)]
sns.distplot(posterior_a)
```

But this crashes with the error in the topic title. What am I doing wrong?

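It's exactly what the error says: some of your tensors are FloatTensors (float32) and some are DoubleTensors (float64). `np.array` stores Python floats as float64 and `torch.from_numpy` keeps that dtype, so `x1`, `truncation_label` and `y2` are DoubleTensors, while the sample sites `a_model` and `b_model` come out as FloatTensors, PyTorch's default dtype. You can fix this by converting everything to float: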

```python
x1 = torch.from_numpy(np.array(x11)).float()
```

and likewise for `truncation_label` and `y2`.
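A small self-contained check of where each dtype comes from; it needs only NumPy and PyTorch. The scalar `param` is a hypothetical stand-in for a Pyro sample site, assuming (as with `dist.Normal(0, 10)`) that it is built with PyTorch's default dtype. It also shows `torch.tensor`, which builds a float32 tensor directly from the list, as an alternative to `.float()`:

```python
import numpy as np
import torch

x11 = [4.6151, 0.0000, 4.6151, 0.0000, 0.0000, 0.0000, 0.0000, 4.6151, 4.6151]

# NumPy stores Python floats as float64 and torch.from_numpy keeps that dtype,
# so this is a DoubleTensor (torch.float64).
x1_double = torch.from_numpy(np.array(x11))

# Hypothetical stand-in for a sample site such as a_model: a tensor built from
# a Python float uses the default dtype, float32 unless it has been changed.
param = torch.tensor(0.5)

# Fix 1: cast the data to float32 explicitly, as suggested above.
x1_float = torch.from_numpy(np.array(x11)).float()

# Fix 2: skip the NumPy round-trip; torch.tensor on a list of Python floats
# also uses the default dtype (float32).
x1_direct = torch.tensor(x11)

print(x1_double.dtype, param.dtype, x1_float.dtype, x1_direct.dtype)
```

If you would rather keep float64 everywhere, calling `torch.set_default_dtype(torch.float64)` before building the model should also work, since Pyro's sample sites then default to float64 too. What matters is that the data and the sampled parameters share one dtype.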

Also, in the future, for ease of debugging, please include the entire reproducible script, including the import statements.
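For reference, the two branches of the loop in the question correspond to the two terms of the usual right-censored exponential likelihood: an observed event at time t contributes the density f(t) = λ·exp(-λt), and a censored one contributes the survival probability S(t) = 1 - F(t) = exp(-λt). Below is a minimal vectorized sketch in plain PyTorch that evaluates this log-likelihood with every tensor in float32. It assumes the rate λ = 1 / softplus(a·x + b), as in the question, and the function name `censored_exponential_log_lik` is hypothetical, not part of Pyro:

```python
import torch
import torch.nn.functional as F


def censored_exponential_log_lik(a, b, x, t, observed):
    """Hypothetical helper: log-likelihood of right-censored exponential data.

    Assumes rate = 1 / softplus(a * x + b), mirroring `1 / link[i]` in the question.
    observed[i] == 1 means the event time t[i] was seen; 0 means censored at t[i].
    """
    rate = 1.0 / F.softplus(a * x + b)
    # Observed event: log density, log(rate) - rate * t
    log_density = torch.log(rate) - rate * t
    # Censored event: log survival, log(1 - F(t)) = -rate * t
    log_survival = -rate * t
    # Pick the matching term per event and sum over all events
    return torch.where(observed.bool(), log_density, log_survival).sum()


# Toy data from the question, built with torch.tensor so everything is float32
x = torch.tensor([4.6151, 0.0, 4.6151, 0.0, 0.0, 0.0, 0.0, 4.6151, 4.6151])
observed = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
t = torch.tensor([1., 0., 4., 3., 2., 1., 0., 0., 0.])
print(censored_exponential_log_lik(torch.tensor(0.5), torch.tensor(0.0), x, t, observed))
```

`torch.where` evaluates both terms for every event and then selects one, which replaces the Python loop; both terms stay finite for t ≥ 0, so the unselected branch cannot inject NaNs.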