Hierarchical Bayesian Neural Nets

Hi,
I’m new to Pyro and was trying to combine this and this to build a hierarchical bayesian NN. In PyMC3, to turn a non-hierarchical to a hierarchical model, all I need to do is to pass the unique index of relevant level to the shape parameter of the distribution I’m sampling. It seems from the first link this is done in the definition of the plate in pyro. However, when I try to do this, it tells me that it can’t broadcast correctly. I can’t see what I’m doing wrong. Does anyone have an implementation of a hierarchical BNN in pyro/ numpyro I could look at?
This is what I wrote, which is obviously wrong:

def hierarchical_model(StageIndex, X, Y = None, D_H= None):

    D_X, D_Y = X.shape[1], 1

    μ_w1 = numpyro.sample("μ_w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    σ_w1 = numpyro.sample("σ_w1", dist.HalfNormal(100.)

    μ_w2 = numpyro.sample("μ_w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    σ_w2 = numpyro.sample("σ_w2", dist.HalfNormal(100.))

    μ_w3 = numpyro.sample("μ_w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    σ_w3 = numpyro.sample("σ_w3", dist.HalfNormal(100.))

    unique_index = np.unique(StageIndex)
    n_members = len(unique_index)
    with numpyro.plate("plate_i", n_members):
        w1_offset = numpyro.sample(
           "w1_offset", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H)))
        )  
        w1 = μ_w1 + w1_offset * σ_w1

        w2_offset = numpyro.sample(
            "w2_offset", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H)))
        )  
        w2 = μ_w2 + w2_offset * σ_w2

        w3_offset = numpyro.sample(
            "w3_offset", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y)))
        )  
        w3 = μ_w3 + w3_offset * σ_w3

    z1 = nonlin(jnp.matmul(X, w1[StageIndex]))  
    z2 = nonlin(jnp.matmul(z1, w2[StageIndex])) 
    z3 = jnp.matmul(z2, w3[StageIndex]).squeeze(-1)  

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)
    if Y is not None:
        with numpyro.plate("data", X.shape[0]):
            # observe data
            numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y.squeeze(-1))
    else:
        with numpyro.plate("data", X.shape[0]):
            # observe data
            numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)

Hi June, I think you need to declare those D_X, D_H dimensions as event dimensions (by default, they are batch dimensions):

# dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))).to_event()
dist.Normal(0, 1).expand([D_X, D_H]).to_event()
# or to be more explicit, ...to_event(2)
# to move two rightmost batch dimensions to event dimensions

Similarly, I think you need to declare event dimensions for Y.

Event dimensions are dimensions of a variable that do not appear in a plate diagram. See Tensor shapes in Pyro tutorial. Based on your comment, I guess pymc3 assumes all dimensions of a variable are event dimensions?

Thanks, that helped :slight_smile: I figured I could also declare dim = -3 in the plate and that would do too.

I stumbled on this after trying the same thing, but I’m still not getting it to work correctly. The trouble I have with June’s code is this section here:

z1 = nonlin(jnp.matmul(X, w1[StageIndex]))  
z2 = nonlin(jnp.matmul(z1, w2[StageIndex])) 
z3 = jnp.matmul(z2, w3[StageIndex]).squeeze(-1)  

Since w1/w2/w3 are now n_member x D_X/D_H x D_H matmul does not seem to work with a simple index. Below is my solution. I was able to estimate something, but then tracing the model for predictions does not work due to broadcasting (so I’m not even sure if it is estimating what I want it to). I’ve seen other solutions use a mask for the maximum number of observations within a group, but that also seems a bit sloppy. Does anyone have any other suggestions? Note below I have dim = -3, but also did .expand().to_event() on w1/2/3gm with the same outcome.

def modelHBNN(X, Y, D_H, ID, D_Y=1):
N, D_X = X.shape

n_groups = len(np.unique(ID))

#for testing, add rng_key
#key = random.PRNGKey(1)
#D_H = 4
#with numpyro.handlers.seed(rng_seed=0):
mu1    = numpyro.sample("mu1", dist.Normal(0.0, 500.0))
mu2    = numpyro.sample("mu2", dist.Normal(0.0, 500.0))
mu3    = numpyro.sample("mu3", dist.Normal(0.0, 500.0))
sigma1 = numpyro.sample("sigma1", dist.HalfNormal(100.0))
sigma2 = numpyro.sample("sigma2", dist.HalfNormal(100.0))
sigma3 = numpyro.sample("sigma3", dist.HalfNormal(100.0))
w1gs =  numpyro.sample("w1gs", dist.HalfNormal(100))
w2gs =  numpyro.sample("w2gs", dist.HalfNormal(100))
w3gs =  numpyro.sample("w3gs", dist.HalfNormal(100))
modelsigma = numpyro.sample("modelsigma", dist.HalfNormal(100.0))

w1raw = numpyro.sample("w1raw", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
w2raw = numpyro.sample("w2raw", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
w3raw = numpyro.sample("w3raw", dist.Normal(jnp.zeros((D_H)), jnp.ones((D_H))))

#with numpyro.handlers.seed(rng_seed=0):                    
with numpyro.plate("new", n_groups,dim=-3):
    w1gm =  numpyro.sample("w1gm", dist.Normal(mu1, sigma1))
    w2gm =  numpyro.sample("w2gm", dist.Normal(mu2, sigma2))
    w3gm =  numpyro.sample("w3gm", dist.Normal(mu3, sigma3))        
    
w1    = w1raw*w1gs + w1gm
w2    = w2raw*w2gs + w2gm
w3    = w3raw*w3gs + w3gm

#w1.shape #n_groups x D_X x D_H
#w2.shape #n_groups x D_H x D_H
#w3.shape #n_groups x 1 x D_H
#      
#first layer
w1exp = w1[ID,:,:] #expand to N x D_X x D_H
z1    = jnp.tanh(vmap(jnp.matmul,in_axes=0,out_axes=0)(X,w1exp)) # <= first layer of activations
assert z1.shape == (N, D_H)    

#second layer
w2exp = w2[ID,:,:] #expand to N x D_X x D_H
z2 = jnp.tanh(vmap(jnp.matmul,in_axes=0,out_axes=0)(z1,w2exp))  # <= second layer of activations
assert z2.shape == (N, D_H)

#third layer
w3exp = jnp.squeeze(w3[ID,:,:],1) #expand to N x 1 x D_H and then squeeze out middle -> N x D_H
z3in  = vmap(jnp.matmul,in_axes=0,out_axes=0)(z2,w3exp)  # <= third layer of activations
z3    = jnp.expand_dims(z3in,1)
assert z3.shape == (N, D_Y)

if Y is not None:
    assert z3.shape == Y.shape            

with numpyro.plate("data", N):
    # note we use to_event(1) because each observation has shape (1,)
    numpyro.sample("Y", dist.Normal(z3, modelsigma).to_event(1), obs=Y)