I am new to pyro and have been trying the sample code explained here Modeling Censored Time-to-Event Data Using Pyro, an Open Source Probabilistic Programming Language | Uber Blog
Following is the code:
import pyro
import torch
import seaborn as sns
import pyro.distributions as dist
from pyro import infer, optim
from pyro.infer.mcmc import HMC, MCMC
from pyro.infer import EmpiricalMarginal
from matplotlib import pyplot as plt
# Generate Dummy Experiment Data
n = 500
a = 2
b = 4
c = 8
x = dist.Normal(0, 0.34).sample((n,)) # x.shape: (500)
link = torch.nn.functional.softplus(a*x + b) # link.shape: (500)
y = dist.Exponential(rate=1/link).sample() # y.shape: (500)
truncation_label = (y > c).float() # truncation.shape: (500)
y_obs = y.clamp(max=c) # y_obs.shape: (500) values greater than c are replaced by c
# plot actual and observed
actual = sns.regplot(x.numpy(), y.numpy())
obs = sns.regplot(x.numpy(), y_obs.numpy())
plt.legend(['actual', 'observed'])
# Define Model
def model(x, y, truncation_label):
a_model = pyro.sample("a_model", dist.Normal(0, 10))
b_model = pyro.sample("b_model", dist.Normal(0, 10))
link = torch.nn.functional.softplus(a_model * x + b_model)
for i in range(len(x)):
y_hidden_dist = dist.Exponential(rate=1/link[i])
if truncation_label[i] == 0:
y_real = pyro.sample("obs_{}".format(i),
y_hidden_dist,
obs=y[i])
else:
truncation_prob = 1 - y_hidden_dist.cdf(y[i])
pyro.sample("truncation_label_{}".format(i),
dist.Bernoulli(truncation_prob),
obs = truncation_label[i])
# Inference
pyro.clear_param_store()
hmc_kernel = HMC(model, step_size=0.1, num_steps=4)
mcmc_run = MCMC(hmc_kernel, num_samples=5,
warmup_steps=1).run(x, y, truncation_label)
marginal_a = EmpiricalMarginal(mcmc_run, sites='a_model')
posterior_a = [marginal_a.sample() for i in range(50)]
But, it throws AssertionError: trace_dist must be trace posterior distribution object
error at line:
marginal_a = EmpiricalMarginal(mcmc_run, sites='a_model')
Can someone explain this to me? Thanks a lot in advance…
pyro version I am using : 1.2.0