Hi, I’m working my way through Statistical Rethinking with the adapted Pyro code. The code has been helpful so far, although the latest version of the textbook no longer matches up 1:1 with the adapted code, and some code is missing entirely.
Additionally, the quadratic approximation function in the provided rethinking.py module doesn’t work for me. Since I won’t be using quadratic approximation in practice, I figured I’d try SVI instead.
For the linear regression from Chapter 4 (model 4.3), I followed the Introduction to Pyro tutorial as closely as possible. However, the intercept parameter seems to be way off in the posterior check.
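```python
import matplotlib.pyplot as plt
import pandas as pd
import pyro
import pyro.distributions as dist


def model(weight=None, height=None):
    a = pyro.sample("a", dist.Normal(160, 20))
    b = pyro.sample("b", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.Uniform(0, 50))
    mu = a + b * weight
    # Size the plate explicitly by the number of observations.
    with pyro.plate("data", len(weight)):
        return pyro.sample("height", dist.Normal(mu, sigma), obs=height)
```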
```python
pyro.clear_param_store()

# These should be reset each training loop.
auto_guide = pyro.infer.autoguide.AutoNormal(model)
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, auto_guide, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(weight_c, height)
    losses.append(loss)
    if step % 100 == 0:
        print("Elbo loss: {}".format(loss))
```
```python
predictive = pyro.infer.Predictive(model, guide=auto_guide, num_samples=800)
svi_samples = predictive(weight_c, height=None)
svi_height = svi_samples["height"]

predictions = pd.DataFrame({
    'weight': weight_c,
    "h_mean": svi_height.mean(0).detach().cpu().numpy(),
    "h_perc_5": svi_height.kthvalue(int(len(svi_height) * 0.05), dim=0)[0].detach().cpu().numpy(),
    "h_perc_95": svi_height.kthvalue(int(len(svi_height) * 0.95), dim=0)[0].detach().cpu().numpy(),
    'true_height': height,
}).sort_values(by='weight')

f, ax = plt.subplots(figsize=(12, 8))
ax.plot(predictions['weight'], predictions['h_mean'])
ax.fill_between(predictions['weight'], predictions['h_perc_5'], predictions['h_perc_95'], alpha=0.5)
ax.plot(predictions['weight'], predictions['true_height'], "o")
ax.set(xlabel='weight (centered)', ylabel='height')
```
where `weight_c` is the centered weight.
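Centering subtracts the sample mean, so `weight_c == 0` corresponds to the average weight and, in `mu = a + b * weight_c`, the intercept `a` is the expected height at the average weight. A minimal sketch with a few hypothetical weights (illustrative values, not the Howell1 data):

```python
import torch

# Hypothetical raw weights in kg (illustrative values only).
weight = torch.tensor([47.8, 36.5, 31.9, 53.0, 41.3])
# Subtract the sample mean; the centered weights average to (numerically) zero.
weight_c = weight - weight.mean()
print(weight_c)
```

With centered weights, the posterior for `a` should therefore sit near the average height in the data rather than near 0.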
The slope (parameter `b`) seems correct, but the intercept (parameter `a`) is way off: its posterior comes out around Normal(15, 1), when I’d expect it to sit somewhere near 160, as in my Normal(160, 20) prior.
Additionally, it’s not clear to me how to get the variance-covariance matrix from this model as shown in the textbook.
Thoughts much appreciated.