Hi,
I’m new to Pyro and was trying to combine this and this to build a hierarchical bayesian NN. In PyMC3, to turn a non-hierarchical to a hierarchical model, all I need to do is to pass the unique index of relevant level to the shape parameter of the distribution I’m sampling. It seems from the first link this is done in the definition of the plate in pyro. However, when I try to do this, it tells me that it can’t broadcast correctly. I can’t see what I’m doing wrong. Does anyone have an implementation of a hierarchical BNN in pyro/ numpyro I could look at?
This is what I wrote, which is obviously wrong:
def hierarchical_model(StageIndex, X, Y = None, D_H= None):
D_X, D_Y = X.shape[1], 1
μ_w1 = numpyro.sample("μ_w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
σ_w1 = numpyro.sample("σ_w1", dist.HalfNormal(100.)
μ_w2 = numpyro.sample("μ_w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
σ_w2 = numpyro.sample("σ_w2", dist.HalfNormal(100.))
μ_w3 = numpyro.sample("μ_w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
σ_w3 = numpyro.sample("σ_w3", dist.HalfNormal(100.))
unique_index = np.unique(StageIndex)
n_members = len(unique_index)
with numpyro.plate("plate_i", n_members):
w1_offset = numpyro.sample(
"w1_offset", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H)))
)
w1 = μ_w1 + w1_offset * σ_w1
w2_offset = numpyro.sample(
"w2_offset", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H)))
)
w2 = μ_w2 + w2_offset * σ_w2
w3_offset = numpyro.sample(
"w3_offset", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y)))
)
w3 = μ_w3 + w3_offset * σ_w3
z1 = nonlin(jnp.matmul(X, w1[StageIndex]))
z2 = nonlin(jnp.matmul(z1, w2[StageIndex]))
z3 = jnp.matmul(z2, w3[StageIndex]).squeeze(-1)
# we put a prior on the observation noise
prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
sigma_obs = 1.0 / jnp.sqrt(prec_obs)
if Y is not None:
with numpyro.plate("data", X.shape[0]):
# observe data
numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y.squeeze(-1))
else:
with numpyro.plate("data", X.shape[0]):
# observe data
numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)