Hi. I’m trying to understand hierarchical models and how they are implemented in Pyro, starting with the eight schools example in the GitHub repository (the SVI version).
At first I was confused by the formulation of the model, until I realised it was a “non-centred” implementation, as described here. Now I’m trying to write a “centred” version, which should be easy but is really confusing me.
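For reference, here are the usual non-centred and centred forms written out. The symbols and priors are read off the code below, with $j$ indexing the $J = 8$ schools. The second argument of $\mathcal{N}$ is the standard deviation, as in Pyro's `dist.Normal`, and $\mu$ and $\tau$ are single values shared by all schools, as in the standard textbook model. Given $\mu$ and $\tau$, both forms give $\theta_j$ the same distribution, because $\mu + \tau\,\eta_j$ with $\eta_j \sim \mathcal{N}(0, 1)$ is itself $\mathcal{N}(\mu, \tau)$-distributed.

```latex
\begin{aligned}
\text{priors:}\quad & \mu \sim \mathcal{N}(0, 10), \quad \tau \sim \mathrm{HalfCauchy}(25) \\
\text{non-centred:}\quad & \eta_j \sim \mathcal{N}(0, 1), \quad \theta_j = \mu + \tau\,\eta_j, \quad y_j \sim \mathcal{N}(\theta_j, \sigma_j) \\
\text{centred:}\quad & \theta_j \sim \mathcal{N}(\mu, \tau), \quad y_j \sim \mathcal{N}(\theta_j, \sigma_j)
\end{aligned}
```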
This is the original “non-centred” version:
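```python
import torch
import pyro
import pyro.distributions as dist

J = 8  # number of schools

def model(data):
    y = data[:, 0]      # observed treatment effect per school
    sigma = data[:, 1]  # known standard error per school
    eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
    mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))
    theta = mu + tau * eta  # non-centred: shift and scale a standard normal
    pyro.sample("obs", dist.Normal(theta, sigma), obs=y)
```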
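For fixed `mu` and `tau`, the line `theta = mu + tau * eta` gives `theta` the same distribution as sampling it directly from `Normal(mu, tau)`. The two forms differ only in which variables the inference algorithm has to work with. Below is a quick Monte Carlo sketch of that equivalence in plain PyTorch. The values `mu = 3.0` and `tau = 2.0` are hypothetical, chosen only for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical fixed hyperparameter values, for illustration only.
mu, tau = 3.0, 2.0
n = 200_000

# Non-centred: draw standard-normal eta, then shift by mu and scale by tau.
eta = torch.randn(n)
theta_noncentred = mu + tau * eta

# Centred: draw theta directly from Normal(mu, tau); tau is the standard deviation.
theta_centred = torch.distributions.Normal(mu, tau).sample((n,))

# Both sample means should be close to mu and both standard deviations close to tau.
for name, t in [("non-centred", theta_noncentred), ("centred", theta_centred)]:
    print(f"{name}: mean={t.mean().item():.3f}, std={t.std().item():.3f}")
```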
And this is my “centred” model:
```python
def model_centred(data):
    y = data[:, 0]
    sigma = data[:, 1]
    # eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
    mu = pyro.sample('mu', dist.Normal(torch.zeros(J), 10 * torch.ones(J)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(J)))
    # theta = mu + tau * eta
    theta = pyro.sample('theta', dist.Normal(mu, tau))
    pyro.sample("obs", dist.Normal(theta, sigma), obs=y)
```
I’ve basically just removed eta, sampled theta directly from a Normal instead, and changed the shapes of mu and tau from 1 to J. In the guide, I likewise removed eta and changed the shapes of mu and tau to J.
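The effect of the shape change can be seen in isolation. Below is a minimal sketch showing how the shapes of `mu` and `tau` determine the shape of `theta`. It uses `torch.distributions.Normal` directly (Pyro's `dist.Normal` wraps it). All-zeros and all-ones tensors stand in for sampled hyperparameters and for `sigma`, and the names `mu_shared`, `mu_per_school`, etc. are hypothetical.

```python
import torch
from torch.distributions import Normal

J = 8  # number of schools

# Shape (1,), as in the original model: one mu and one tau shared by all schools.
mu_shared = torch.zeros(1)
tau_shared = torch.ones(1)

# Shape (J,), as in model_centred: a separate mu_j and tau_j for every school.
mu_per_school = torch.zeros(J)
tau_per_school = torch.ones(J)

# Shared (1,)-shaped parameters give a single theta of shape (1,).
print(Normal(mu_shared, tau_shared).sample().shape)  # torch.Size([1])

# Expanding the distribution gives J thetas drawn from the same shared mu and tau.
print(Normal(mu_shared, tau_shared).expand([J]).sample().shape)  # torch.Size([8])

# (J,)-shaped parameters also give J thetas, but each has its own independent prior.
print(Normal(mu_per_school, tau_per_school).sample().shape)  # torch.Size([8])

# A (1,)-shaped theta still broadcasts against (J,)-shaped sigma in the likelihood,
# so the obs distribution has batch shape (J,), with one theta reused for every school.
sigma = torch.ones(J)  # stand-in for the per-school standard errors
theta_single = Normal(mu_shared, tau_shared).sample()
print(Normal(theta_single, sigma).batch_shape)  # torch.Size([8])
```

In Pyro itself, `.expand([J])` on a distribution or a `pyro.plate` over the schools are the usual ways to draw `J` values of `theta` from shared hyperparameters.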
Do these changes make sense? Are there other changes I should make?
The model runs ok, but I’m not sure how to interpret the output.