ReLU BNN doesn't work

Hi,

First of all, thanks for NumPyro, it seems very promising! I hope this question isn’t too stupid :slight_smile:

I’m trying to sample from a BNN using HMC to solve a simple regression task. It works fine with a tanh activation, but it just returns a linear model when I try to use ReLU instead. I’ve had this issue for some time now and can’t seem to find anything wrong with the code. I’ve tried other formulations of the activation function, e.g. x*(x>0), but it doesn’t seem to make a difference. Any ideas?

Here’s a somewhat minimal example:

import numpy as onp

from matplotlib import pyplot as plt
from jax import vmap
from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random

import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc

class HMCBNNNP:
    def __init__(self, X, Y):
        jax_config.update('jax_platform_name', 'cpu')

        self.num_warmup = 500
        self.num_samples = 100
        self.hidden_units = 10

        # do inference
        self.rng, self.rng_predict = random.split(random.PRNGKey(0))
        self.samples = run_inference(model, self.rng, X, Y, self.hidden_units, self.num_warmup, self.num_samples)

    def predict(self, x):
        all_preds = []

        for i in range(self.num_samples):
            y_pred = nn(x, self.samples['w1'][i], self.samples['w2'][i], self.samples['w3'][i])
            all_preds.append(y_pred)
        all_preds = np.stack(all_preds)

        y_pred = np.mean(all_preds, axis=0)

        return y_pred

# the non-linearity we use in our neural network
def relu(x):
    return np.maximum(x, 0.)
    #return np.tanh(x)

# a two-layer bayesian neural network with computational flow
# given by D_X => D_H => D_H => D_Y where D_H is the number of
# hidden units. (note we indicate tensor dimensions in the comments)
def model(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1

    w1 = sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
    z1 = relu(np.matmul(X, w1))

    w2 = sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H))))
    z2 = relu(np.matmul(z1, w2))

    w3 = sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y))))
    y_pred = np.matmul(z2, w3)

    sigma = 0.01 * np.ones_like(y_pred)  # homoscedastic observation noise (fixed standard deviation) for now

    # observe data
    sample("Y", dist.Normal(y_pred, sigma), obs=Y)

def nn(x, w1, w2, w3):
    x = relu(np.matmul(x, w1))
    x = relu(np.matmul(x, w2))
    x = np.matmul(x, w3)
    return x

# helper function for HMC inference
def run_inference(model, rng, X, Y, D_H, num_warmup, num_samples):
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H)
    samples = mcmc(num_warmup, num_samples, init_params, sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn)
    return samples

if __name__ == '__main__':
    x = np.linspace(-2, 2, 100).reshape(-1, 1)
    y = x**3
    x_test = np.linspace(-3, 3, 100).reshape(-1, 1)

    hmc = HMCBNNNP(x, y)
    y_pred = hmc.predict(x_test)

    plt.scatter(x, y)
    plt.plot(x_test, y_pred)
    plt.show()

HMC requires differentiable probability densities, something which your ReLU activation is going to do damage to. If you want something ReLU-like that is still differentiable, you might try ELU.
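
For reference, ELU is x on the positive side and alpha * (exp(x) - 1) on the negative side. A drop-in replacement for the relu() above could look something like this (just a sketch, using the same jax.numpy-as-np alias as your script):

# one possible smooth, ReLU-like activation (ELU)
def elu(x, alpha=1.):
    # clamping the argument of exp() keeps the unused branch of where() finite,
    # so gradients stay well-defined for large positive x
    return np.where(x >= 0, x, alpha * (np.exp(np.minimum(x, 0.)) - 1))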

Thank you. ELU works, but I’ve encountered a different issue. I’ve added another layer that outputs the variance, and initialization (what I call the time before warmup starts) takes a very long time even with what seems to me to be relatively few units. It takes under 5 seconds with 10 or 15 units per layer, but with 20 units warmup doesn’t even start for me (I’ve waited for at least 10-15 minutes).

I understand that HMC doesn’t scale fantastically, but this behavior seems a bit strange to me. Or is this to be expected?

import numpy as onp

import seaborn as sns; sns.set()
from matplotlib import pyplot as plt

from jax import vmap
from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random

import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc

class HMCBNNNP:
    def __init__(self, X, Y):
        jax_config.update('jax_platform_name', 'cpu')

        self.num_warmup = 1000
        self.num_samples = 100
        self.hidden_units = 20

        self.rng, self.rng_predict = random.split(random.PRNGKey(0))
        self.samples = run_inference(model, self.rng, X, Y, self.hidden_units, self.num_warmup, self.num_samples)

    def predict(self, x):
        all_means = []
        all_stds = []

        for i in range(self.num_samples):
            mu, sigma = nn(x, self.samples['w1'][i], self.samples['w2'][i], self.samples['w3'][i], self.samples['w4'][i])
            all_means.append(mu)
            all_stds.append(sigma)
        all_means = np.stack(all_means)
        all_stds = np.stack(all_stds)

        y_pred_mu = np.mean(all_means, axis=0)
        y_pred_var = np.mean(all_stds**2 + all_means**2, axis=0) - y_pred_mu**2  # law of total variance
        y_pred_std = np.sqrt(y_pred_var)

        return y_pred_mu, y_pred_std

def elu(x):
    return x * (x >= 0) + (np.exp(x) - 1) * (x < 0)

def model(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1

    prior_sigma = 1

    w1 = sample("w1", dist.Normal(np.zeros((D_X, D_H)), prior_sigma*np.ones((D_X, D_H)))) 
    z1 = elu(np.matmul(X, w1))

    w2 = sample("w2", dist.Normal(np.zeros((D_H, D_H)), prior_sigma*np.ones((D_H, D_H))))
    z2 = elu(np.matmul(z1, w2))

    w3 = sample("w3", dist.Normal(np.zeros((D_H, D_Y)), prior_sigma*np.ones((D_H, D_Y))))
    w4 = sample("w4", dist.Normal(np.zeros((D_H, D_Y)), prior_sigma*np.ones((D_H, D_Y))))
    
    mu = np.matmul(z2, w3)
    sigma = np.sqrt(np.exp(np.matmul(z2, w4)))

    sample("Y", dist.Normal(mu, sigma), obs=Y)

def nn(x, w1, w2, w3, w4):
    x = elu(np.matmul(x, w1))
    x = elu(np.matmul(x, w2))
    mu = np.matmul(x, w3)
    sigma = np.sqrt(np.exp(np.matmul(x, w4)))
    return mu, sigma

def run_inference(model, rng, X, Y, D_H, num_warmup, num_samples):
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H)
    samples = mcmc(num_warmup, num_samples, init_params, potential_fn=potential_fn, constrain_fn=constrain_fn)

    return samples

if __name__ == '__main__':
    x = np.linspace(-2, 2, 100).reshape(-1, 1)
    y = x**3 + onp.random.randn(100, 1)
    x_test = np.linspace(-3, 3, 100).reshape(-1, 1)

    hmc = HMCBNNNP(x, y)
    y_pred_mu, y_pred_std = hmc.predict(x_test)

    plt.scatter(x, y, zorder=1)
    plt.plot(x_test, y_pred_mu, color='red', zorder=2)
    plt.fill_between(x_test.reshape(-1), (y_pred_mu - 2*y_pred_std).reshape(-1), (y_pred_mu + 2*y_pred_std).reshape(-1), color=(255/255, 195/255, 193/255), zorder=-2)
    plt.ylim(-30, 30)
    plt.show()

It’s hard to say, but it’s not too surprising: modeling the observation noise like that can be problematic, since small changes in the weights can lead to pretty big changes in the variance and thus big changes in the log probability.
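
As a rough numerical illustration (the numbers below are made up, but the sigma transform is the same sqrt(exp(...)) used in the model): shifting the pre-activation z2 @ w4 by just one unit already rescales sigma by exp(0.5) ≈ 1.65, and the observation log probability moves accordingly.

import jax.numpy as np
import numpyro.distributions as dist

y, mu = 3.0, 0.0
for pre_activation in [0.0, 1.0, 2.0]:           # small shifts in z2 @ w4
    sigma = np.sqrt(np.exp(pre_activation))      # same transform as in the model
    print(pre_activation, sigma, dist.Normal(mu, sigma).log_prob(y))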

@jboyml I think that HMC just couldn’t find a good initial step_size in this case. That indicates that there are some problems with the model. I guess the number of hidden units is the problem. We’ll make a fix for the issue where warmup doesn’t even start, so mcmc will be able to run.

@jboyml The issue is fixed in the master branch. You might notice that the step_size will be very small and the result is really bad for such a large number of hidden units. :slight_smile: