Hierarchical Bayesian Neural Nets

I’m new to Pyro and was trying to combine this and this to build a hierarchical Bayesian NN. In PyMC3, to turn a non-hierarchical model into a hierarchical one, all I need to do is pass the number of unique values of the relevant level’s index to the shape parameter of the distribution I’m sampling. From the first link, it seems that in Pyro this is done in the plate definition instead. However, when I try this, I get an error saying it can’t broadcast correctly, and I can’t see what I’m doing wrong. Does anyone have an implementation of a hierarchical BNN in Pyro/NumPyro I could look at?
This is what I wrote, which is obviously wrong:

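import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def nonlin(x):
    # assumed activation (tanh), as in the NumPyro BNN example
    return jnp.tanh(x)


def hierarchical_model(StageIndex, X, Y=None, D_H=None):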

    D_X, D_Y = X.shape[1], 1

    μ_w1 = numpyro.sample("μ_w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    σ_w1 = numpyro.sample("σ_w1", dist.HalfNormal(100.))

    μ_w2 = numpyro.sample("μ_w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    σ_w2 = numpyro.sample("σ_w2", dist.HalfNormal(100.))

    μ_w3 = numpyro.sample("μ_w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    σ_w3 = numpyro.sample("σ_w3", dist.HalfNormal(100.))

    unique_index = np.unique(StageIndex)
    n_members = len(unique_index)
    with numpyro.plate("plate_i", n_members):
        w1_offset = numpyro.sample(
            "w1_offset", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
        w1 = μ_w1 + w1_offset * σ_w1

        w2_offset = numpyro.sample(
            "w2_offset", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
        w2 = μ_w2 + w2_offset * σ_w2

        w3_offset = numpyro.sample(
            "w3_offset", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
        w3 = μ_w3 + w3_offset * σ_w3

    z1 = nonlin(jnp.matmul(X, w1[StageIndex]))  
    z2 = nonlin(jnp.matmul(z1, w2[StageIndex])) 
    z3 = jnp.matmul(z2, w3[StageIndex]).squeeze(-1)  

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)
    # observe data (Y is None when sampling from the predictive)
    with numpyro.plate("data", X.shape[0]):
        numpyro.sample("Y", dist.Normal(z3, sigma_obs),
                       obs=None if Y is None else Y.squeeze(-1))

Hi June, I think you need to declare those D_X, D_H dimensions as event dimensions (by default, they are batch dimensions):

# dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))).to_event()
dist.Normal(0, 1).expand([D_X, D_H]).to_event()
# or, to be more explicit, ...to_event(2)
# moves the two rightmost batch dimensions to event dimensions

Similarly, I think you need to declare event dimensions for Y.

Event dimensions are the dimensions of a variable that do not appear in a plate diagram; see the Tensor shapes in Pyro tutorial. Based on your comment, I guess PyMC3 treats all dimensions of a variable as event dimensions?

Thanks, that helped. I also figured out that passing dim=-3 to the plate works too.