AssertionError: trace_dist must be trace posterior distribution object

I am new to Pyro and have been trying the sample code explained here: https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/

Following is the code:

    import pyro
    import torch
    import seaborn as sns
    import pyro.distributions as dist
    from pyro import infer, optim
    from pyro.infer.mcmc import HMC, MCMC
    from pyro.infer import EmpiricalMarginal
    from matplotlib import pyplot as plt


    # Simulate a synthetic experiment whose outcomes are capped at a threshold
    n = 500          # number of simulated observations
    a, b, c = 2, 4, 8  # slope, intercept, and the cap applied to y

    # covariate draws; x.shape: (500)
    x = dist.Normal(0, 0.34).sample((n,))

    # softplus output is positive, so 1/link is a valid rate; link.shape: (500)
    link = torch.nn.functional.softplus(a * x + b)

    # uncapped outcomes; y.shape: (500)
    y = dist.Exponential(rate=1 / link).sample()

    # 1.0 where y exceeds c, 0.0 otherwise; truncation_label.shape: (500)
    truncation_label = torch.gt(y, c).float()

    # values greater than c are replaced by c; y_obs.shape: (500)
    y_obs = torch.clamp(y, max=c)



    # plot actual and observed
    # x and y are passed by keyword: seaborn >= 0.12 accepts only `data`
    # positionally, so the positional form raises TypeError there.
    actual = sns.regplot(x=x.numpy(), y=y.numpy())
    obs = sns.regplot(x=x.numpy(), y=y_obs.numpy())
    plt.legend(['actual', 'observed'])


    # Define Model
    def model(x, y, truncation_label):
        """Exponential regression model with a separate likelihood term for truncated rows.

        Samples the latents ``a_model`` and ``b_model`` from Normal(0, 10)
        priors, builds ``softplus(a_model * x + b_model)`` and uses its
        reciprocal as the rate of an Exponential distribution per data point.

        Args:
            x: covariates; indexed per data point and combined with the
                latents by broadcasting.
            y: outcome values, indexed in step with ``x``.
            truncation_label: per-point flag; ``0`` means ``y[i]`` is
                conditioned on directly, anything else means only the
                event "outcome exceeds ``y[i]``" is conditioned on.

        Returns:
            None. The function only registers Pyro sample sites.
        """
        a_model = pyro.sample("a_model", dist.Normal(0, 10))
        b_model = pyro.sample("b_model", dist.Normal(0, 10))

        scale = torch.nn.functional.softplus(a_model * x + b_model)

        for idx in range(len(x)):
            hidden = dist.Exponential(rate=1 / scale[idx])

            if truncation_label[idx] != 0:
                # Truncated point: condition on the probability that the
                # hidden outcome lies beyond y[idx], i.e. 1 - CDF(y[idx]).
                beyond_prob = 1 - hidden.cdf(y[idx])
                pyro.sample(f"truncation_label_{idx}",
                            dist.Bernoulli(beyond_prob),
                            obs=truncation_label[idx])
                continue

            # Fully observed point: condition the Exponential on y[idx].
            pyro.sample(f"obs_{idx}", hidden, obs=y[idx])


    # Inference
    # Clear Pyro's parameter store before running inference.
    pyro.clear_param_store()
 
    # HMC kernel over `model` with step_size=0.1 and num_steps=4.
    hmc_kernel = HMC(model, step_size=0.1, num_steps=4)

    # Only 5 samples after 1 warmup step. `mcmc_run` is bound to whatever
    # `.run(...)` returns, not to the MCMC object itself.
    # NOTE(review): the model is given the uncapped `y`, not `y_obs` -- confirm
    # this is intended.
    mcmc_run = MCMC(hmc_kernel, num_samples=5,
                warmup_steps=1).run(x, y, truncation_label)

    # NOTE: on Pyro 1.2.0 the next line is where
    # "AssertionError: trace_dist must be trace posterior distribution object"
    # is raised (see the discussion following this snippet).
    marginal_a = EmpiricalMarginal(mcmc_run, sites='a_model')
    # Draw 50 values from the marginal of the 'a_model' site.
    posterior_a = [marginal_a.sample() for i in range(50)]

But it throws the error `AssertionError: trace_dist must be trace posterior distribution object` at this line:

marginal_a = EmpiricalMarginal(mcmc_run, sites='a_model')

Can someone explain this to me? Thanks a lot in advance…

Pyro version I am using: 1.2.0

Hi @Abhishek, the EmpiricalMarginal class is deprecated. For now, you can get samples from MCMC with mcmc_run.get_samples().

1 Like

Thanks a lot @fehiepsi for the help.

Just for clarification, adding the modified code here:

        # Keep a handle on the MCMC object itself, because get_samples() is
        # called on it after run() has finished.
        mcmc = MCMC(hmc_kernel, num_samples=5,
                warmup_steps=1)
        # Fixed typo: the variable defined earlier is `truncation_label`
        # (`turncation_label` would raise NameError).
        mcmc.run(x, y, truncation_label)

        # Posterior draws for the 'a_model' site.
        # NOTE(review): 50 draws are requested from a chain run with
        # num_samples=5 -- confirm this is the intended count.
        posterior_a = mcmc.get_samples(50)['a_model']