Hierarchical BNN

Hi everyone!

I’m quite new to probabilistic programming. I’m trying to convert a simple hierarchical two-layer Bayesian neural network for the make_moons dataset from PyMC3 to NumPyro. You can see the PyMC3 version at pyprobml.

Here is the implementation.

import jax.numpy as jnp
from jax import random, vmap

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDiagonalNormal

'''
A two-layer Bayesian neural network with computational flow
given by D_X => D_H => D_H => D_Y, where D_H is the number of
hidden units.
'''
def model(X, Y, D_H):
    N, D_X = X.shape
    D_Y = 1  
    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))  # D_X D_H
    z1 = jnp.tanh(jnp.matmul(X, w1))   # N D_H  <= first layer of activations

    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))  # D_H D_H
    z2 = jnp.tanh(jnp.matmul(z1, w2))  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))  # D_H D_Y
    z3 = jnp.matmul(z2, w3)  # N D_Y  <= output of the neural network
    # observe data
    with numpyro.plate("classes", N):
        Y = numpyro.sample("Y", dist.Bernoulli(logits=z3), obs=Y)

def run_inference(model, rng_key, X, Y, D_H):
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim=numpyro.optim.Adam(step_size=1e-2), loss=numpyro.infer.Trace_ELBO())
    svi_result = svi.run(rng_key, 10000, X, Y, D_H)
    # draw posterior samples from the fitted guide for the prediction step below
    return guide.sample_posterior(rng_key, svi_result.params, sample_shape=(1000,))
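
The loop below calls a predict helper that is not shown in the post; for reference, here is a sketch along the lines of the helper in NumPyro's BNN example, which substitutes one posterior draw into the model and traces it with Y=None so the site "Y" gets sampled rather than observed:

from numpyro import handlers

def predict(model, rng_key, samples, X, D_H):
    # condition the latent weights on a single posterior draw
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # run the model forward; Y=None means the "Y" site is sampled, not observed
    model_trace = handlers.trace(model).get_trace(X=X, Y=None, D_H=D_H)
    return model_trace["Y"]["value"]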

pred_train, pred_test = [], []
D_H = 5

# Xs_train, Ys_train, Xs_test, Ys_test come from the make_moons data setup (not shown)
for X_train, Y_train, X_test, Y_test in zip(Xs_train, Ys_train, Xs_test, Ys_test): 
    # do inference
    rng_key, rng_key_train, rng_key_test = random.split(random.PRNGKey(0), 3)
    samples = run_inference(model, rng_key, X_train, Y_train, D_H)

    # predict Y_train at inputs X_train
    vmap_args = (samples, random.split(rng_key_train, 1000))
    predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X_train, D_H))(*vmap_args)
    predictions = predictions[..., 0]
 
    # compute mean prediction (a credible-interval sketch follows after the loop)
    mean_prediction = jnp.mean(predictions, axis=0)
    pred_train.append(mean_prediction > 0.5)
    print(mean_prediction)
    break
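
As flagged in the comment above, the upstream NumPyro example also derives a 90% interval around the median from the same prediction samples; a minimal sketch:

# 5th/95th percentiles of the posterior predictive draws
percentiles = jnp.percentile(predictions, jnp.array([5.0, 95.0]), axis=0)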

Here are the results.

> [[0.473 0.496 0.517 0.504 0.512 0.501 0.511 0.484 0.487 0.49  0.465 0.506
>   0.494 0.496 0.524 0.488 0.521 0.478 0.499 0.472 0.493 0.52  0.489 0.508
>   0.502 0.517 0.501 0.498 0.502 0.515 0.509 0.509 0.489 0.501 0.516 0.49
>   0.495 0.505 0.5   0.503 0.474 0.51  0.484 0.47  0.506 0.512 0.516 0.516
>   0.5   0.497]]

I have gone through NumPyro's documentation, examples, and source code, but I could not solve the problem: the loss decreases slightly and then stays flat at a very high value. Furthermore, the mean predictions are all very close to one half, which implies the model learns nothing and the prediction for any sample point is just a coin flip.
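
For anyone debugging the same symptom: the per-step losses are recorded on the SVIRunResult that svi.run returns, so the plateau can be inspected directly. A minimal sketch, assuming run_inference is adjusted to also return svi_result:

import matplotlib.pyplot as plt

# svi_result.losses holds the ELBO loss at every optimization step
plt.plot(svi_result.losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss")
plt.show()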

Thanks in advance for your answers.

I have already solved the problem. This line works for me :slight_smile:

Y = numpyro.sample("Y", dist.Bernoulli(logits=z3), obs=Y.reshape((-1, 1)) if Y is not None else Y)

Anyway. Thanks a lot…

I am not sure that’s what you want. Y.reshape((-1, 1)) says that the site "Y" has a value of shape (N, 1). The plate that you used

with numpyro.plate("classes", N):

says that the last dimension is the plate dimension, so NumPyro will think that your site "Y" has shape (N, N).
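
For context, the original code (before any reshape) paired logits of shape (N, 1) with observations of shape (N,), which silently broadcasts the Bernoulli log-probability to (N, N); a quick check:

import jax.numpy as jnp
import numpyro.distributions as dist

logits = jnp.zeros((5, 1))  # like z3: shape (N, 1)
obs = jnp.zeros(5)          # like Y before reshaping: shape (N,)
print(dist.Bernoulli(logits=logits).log_prob(obs).shape)  # (5, 5)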

Because z3 has shape (N, 1), there are two ways to fix the issue:

# remove the last singleton dimension of z3
Y = numpyro.sample("Y", dist.Bernoulli(logits=z3.squeeze(-1)), obs=Y)

or

# declare the last singleton dimension is an event dimension
Y = numpyro.sample("Y", dist.Bernoulli(logits=z3).to_event(1), obs=Y.reshape((-1, 1)))
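
With either fix, the per-site shapes can be double-checked by tracing the model; a sketch, assuming a reasonably recent NumPyro where numpyro.util.format_shapes is available:

import jax.numpy as jnp
from jax import random
from numpyro import handlers
from numpyro.util import format_shapes

X_dummy, Y_dummy = jnp.zeros((50, 2)), jnp.zeros(50)
trace = handlers.trace(handlers.seed(model, random.PRNGKey(0))).get_trace(X_dummy, Y_dummy, 5)
print(format_shapes(trace))  # prints the batch/event shape of every sample site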

Did you work out how to make this hierarchical? I’d be interested as well; I recently asked a related question here.
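
Not a definitive answer, but one common way to make such a network hierarchical is to put a hyperprior on the weight scales and share it within each layer. A minimal sketch under that assumption (the names hierarchical_model, s1, s2, s3 are mine, not from the thread):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def hierarchical_model(X, Y, D_H):
    N, D_X = X.shape
    # hyperpriors: one shared scale per weight matrix
    s1 = numpyro.sample("s1", dist.HalfNormal(1.0))
    s2 = numpyro.sample("s2", dist.HalfNormal(1.0))
    s3 = numpyro.sample("s3", dist.HalfNormal(1.0))
    # weights drawn from zero-mean normals with the learned scales
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), s1))
    z1 = jnp.tanh(X @ w1)
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), s2))
    z2 = jnp.tanh(z1 @ w2)
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, 1)), s3))
    logits = (z2 @ w3).squeeze(-1)  # shape (N,), matching the fix above
    with numpyro.plate("classes", N):
        numpyro.sample("Y", dist.Bernoulli(logits=logits), obs=Y)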

Sorry for the late response. Thanks for both suggestions :slight_smile: