Help with Hierarchical Linear Regression with Partial Pooling

Hey guys,
I need help with the following task. I’m training a hierarchical linear regression with partial pooling. The code below shows exactly what I’m trying to do, but the idea is pretty straightforward.
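
Written out, the model in both snippets below is roughly the following. This is a sketch read off the code: the second argument of each normal is a standard deviation, and the indices i (patient) and j (a patient's j-th observation) are notation introduced here, not something from the code.

```latex
\begin{aligned}
\mu_a,\ \mu_b &\sim \mathcal{N}(0,\ 100) \\
\sigma_a,\ \sigma_b,\ \sigma &\sim \text{HalfNormal}(100) \\
a_i &\sim \mathcal{N}(\mu_a,\ \sigma_a), \quad
b_i \sim \mathcal{N}(\mu_b,\ \sigma_b), \quad i = 1, \dots, n_{\text{patients}} \\
\text{FVC}_{ij} &\sim \mathcal{N}(a_i + b_i \cdot \text{Weeks}_{ij},\ \sigma)
\end{aligned}
```

The partial pooling comes from the shared hyperparameters: each patient gets their own intercept and slope, but all of them are drawn from common population-level distributions.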

My problem is this: the same model takes 2 minutes in PyMC3, but in Pyro it has been running for over 2 hours (and still hasn't finished!). So I must be doing something wrong. Can you help me figure out what it is? To compare, I’m including both versions, in PyMC3 and in Pyro. I’d love to finally drop PyMC3 for good, but I need to figure out how to make Pyro’s sampler converge faster…

Code in PyMC3

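import pymc3 as pm

# `train` is assumed to be a pandas DataFrame; `PatientID` is assumed to be
# an integer code (0 .. n_patients - 1) for each patient.
n_patients = train['Patient'].nunique()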
FVC_obs = train['FVC'].values
Weeks = train['Weeks'].values
PatientID = train['PatientID'].values

with pm.Model() as model_a:
    mu_a = pm.Normal('mu_a', mu=0., sigma=100)
    sigma_a = pm.HalfNormal('sigma_a', 100.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100)
    sigma_b = pm.HalfNormal('sigma_b', 100.)

    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_patients)
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_patients)

    sigma = pm.HalfNormal('sigma', 100.)

    FVC_est = a[PatientID] + b[PatientID] * Weeks

    # Data likelihood
    FVC_like = pm.Normal('FVC_like', mu=FVC_est,
                         sigma=sigma, observed=FVC_obs)

    # Fitting the model
    trace_a = pm.sample(2000, tune=2000, target_accept=.9, init="adapt_diag")

Code in Pyro

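import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

FVC_obs = torch.from_numpy(train['FVC'].values)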
Weeks = torch.from_numpy(train['Weeks'].values)
PatientID = torch.from_numpy(train['PatientID'].values)

def model(PatientID, Weeks, FVC_obs=None):
    mu_a = pyro.sample("mu_a", dist.Normal(0., 100.))
    sigma_a = pyro.sample("sigma_a", dist.HalfNormal(100.))
    mu_b = pyro.sample("mu_b", dist.Normal(0., 100.))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(100.))
    
    unique_patient_IDs = np.unique(PatientID)
    n_patients = len(unique_patient_IDs)
    
    a = pyro.sample("a", dist.Normal(mu_a.expand(n_patients), sigma_a))
    b = pyro.sample("b", dist.Normal(mu_b.expand(n_patients), sigma_b))
    
    sigma = pyro.sample("sigma", dist.HalfNormal(100.))
    
    FVC_est = a[PatientID] + b[PatientID] * Weeks
    
    with pyro.plate("data", len(PatientID)):
        pyro.sample("obs", dist.Normal(FVC_est, sigma), obs=FVC_obs)

nuts_kernel = NUTS(model) 
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200)
mcmc.run(PatientID, Weeks, FVC_obs)
posterior_samples = mcmc.get_samples()

Thanks!

For NUTS, we recommend using NumPyro, which will be much faster. The code should be the same, except that numpy.ndarray does not have an .expand method, so at the a and b sites you can use dist.Normal(mu_a, sigma_a).expand([n_patients]) instead. You also don’t need to use torch.from_numpy(...).

If it is still slow, please let me know. We would love to have more benchmark results. :)

YES!!! NumPyro works, and is extremely fast!! Faster than PyMC3! :)

I will finish the code and contribute a tutorial, thanks!!

PS: Question: why should I ever use Pyro again? What are the different use cases for Pyro and NumPyro?

Cheers!

"I will finish the code and contribute with a tutorial"

That is great to hear! I am looking forward to seeing your tutorial. :)

There is a lot of cool stuff in Pyro that is not available in NumPyro; for example, see the Contributed code section of the Pyro docs. For me, PyTorch code is much easier to debug during development than JAX code (though the JAX team has put a lot of effort into making debugging easier in recent releases). Hence, when implementing a new inference algorithm, it is easier for me to work in Pyro.

There you go:

Thanks!!
