Hey guys,
I need help with the following task. I’m training a Hierarchical Linear Regression with Partial Pooling. The code below shows exactly what I’m trying to do, but the idea is pretty straightforward.
My problem is this: the same model takes about 2 minutes to sample in PyMC3, but in Pyro it has been running for over 2 hours and still hasn’t finished. So I must be doing something wrong. Can you help me figure out what? For comparison, I’m including the code in both PyMC3 and Pyro. I’d love to drop PyMC3 for good, but first I need to figure out how to make Pyro’s sampler run faster…
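For reference, this is the generative model that both snippets implement, written out explicitly. The notation is my own: $p(j)$ is the integer `PatientID` of observation $j$, and $\mathcal{N}(\mu, \sigma^2)$ is parameterized by variance, whereas both libraries take the standard deviation.

$$
\begin{aligned}
\mu_a &\sim \mathcal{N}(0, 100^2), & \sigma_a &\sim \operatorname{HalfNormal}(100),\\
\mu_b &\sim \mathcal{N}(0, 100^2), & \sigma_b &\sim \operatorname{HalfNormal}(100),\\
a_i &\sim \mathcal{N}(\mu_a, \sigma_a^2), & b_i &\sim \mathcal{N}(\mu_b, \sigma_b^2), \qquad i = 1, \dots, n_{\text{patients}},\\
\sigma &\sim \operatorname{HalfNormal}(100), & & \\
\text{FVC}_j &\sim \mathcal{N}\!\left(a_{p(j)} + b_{p(j)}\,\text{Weeks}_j,\ \sigma^2\right). & &
\end{aligned}
$$

"Partial pooling" refers to the shared $\mu_a, \sigma_a$ (and $\mu_b, \sigma_b$): each patient gets their own intercept and slope, but these are shrunk toward the population mean.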
Code in PyMC3
import pymc3 as pm

n_patients = train['Patient'].nunique()
FVC_obs = train['FVC'].values
Weeks = train['Weeks'].values
PatientID = train['PatientID'].values  # integer codes 0..n_patients-1

with pm.Model() as model_a:
    # Population-level hyperpriors for intercepts (a) and slopes (b)
    mu_a = pm.Normal('mu_a', mu=0., sigma=100)
    sigma_a = pm.HalfNormal('sigma_a', 100.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100)
    sigma_b = pm.HalfNormal('sigma_b', 100.)
    # Per-patient intercepts and slopes, partially pooled toward mu_a / mu_b
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_patients)
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_patients)
    # Observation noise
    sigma = pm.HalfNormal('sigma', 100.)
    FVC_est = a[PatientID] + b[PatientID] * Weeks
    # Data likelihood
    FVC_like = pm.Normal('FVC_like', mu=FVC_est,
                         sigma=sigma, observed=FVC_obs)
    # Fitting the model
    trace_a = pm.sample(2000, tune=2000, target_accept=.9, init="adapt_diag")
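Both snippets index `a[PatientID]` and `b[PatientID]`, which only works if `PatientID` holds contiguous integer codes `0..n_patients-1` that line up with `train['Patient'].nunique()`. The post doesn't show how that column is built, so here is one possible way to do it. This is an assumption on my part; the helper name `encode_patients` is hypothetical, and the code uses `pandas.factorize`.

```python
from typing import Tuple

import pandas as pd


def encode_patients(train: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
    """Hypothetical helper: add an integer 'PatientID' column derived from 'Patient'.

    pd.factorize assigns codes 0..k-1 in order of first appearance, so the codes
    are contiguous and can be used directly to index per-patient parameter vectors.
    """
    codes, uniques = pd.factorize(train["Patient"])
    out = train.copy()  # leave the caller's DataFrame untouched
    out["PatientID"] = codes
    return out, len(uniques)
```

Used as `train, n_patients = encode_patients(train)`, this keeps `n_patients` consistent with the largest index that appears in `PatientID`.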
Code in Pyro
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Cast data to float32 to match the float32 parameters; IDs stay int64 for indexing
FVC_obs = torch.from_numpy(train['FVC'].values).float()
Weeks = torch.from_numpy(train['Weeks'].values).float()
PatientID = torch.from_numpy(train['PatientID'].values)

def model(PatientID, Weeks, FVC_obs=None):
    # Population-level hyperpriors for intercepts (a) and slopes (b)
    mu_a = pyro.sample("mu_a", dist.Normal(0., 100.))
    sigma_a = pyro.sample("sigma_a", dist.HalfNormal(100.))
    mu_b = pyro.sample("mu_b", dist.Normal(0., 100.))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(100.))
    # Per-patient intercepts and slopes (batch shape n_patients)
    unique_patient_IDs = torch.unique(PatientID)
    n_patients = len(unique_patient_IDs)
    a = pyro.sample("a", dist.Normal(mu_a.expand(n_patients), sigma_a))
    b = pyro.sample("b", dist.Normal(mu_b.expand(n_patients), sigma_b))
    # Observation noise
    sigma = pyro.sample("sigma", dist.HalfNormal(100.))
    FVC_est = a[PatientID] + b[PatientID] * Weeks
    # Data likelihood: one conditionally independent term per row
    with pyro.plate("data", len(PatientID)):
        pyro.sample("obs", dist.Normal(FVC_est, sigma), obs=FVC_obs)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200)
mcmc.run(PatientID, Weeks, FVC_obs)
posterior_samples = mcmc.get_samples()
Thanks!