First of all, thanks for NumPyro, it seems very promising! I hope this question isn’t too stupid
I’m trying to sample from a BNN using HMC to solve a simple regression task, and it works fine with a tanh activation but just returns a linear model when I try to use ReLU instead. I’ve had this issue for some time now and can’t seem to find anything wrong with the code. I’ve tried other formulations of the activation function, e.g. x*(x>0) but doesn’t seem to make a difference. Any ideas?
Here’s a somewhat minimal example:
import numpy as np import numpy as onp from matplotlib import pyplot as plt from jax import vmap from jax.config import config as jax_config import jax.numpy as np import jax.random as random import numpyro.distributions as dist from numpyro.handlers import sample, seed, substitute, trace from numpyro.hmc_util import initialize_model from numpyro.mcmc import mcmc class HMCBNNNP: def __init__(self, X, Y): jax_config.update('jax_platform_name', 'cpu') self.num_warmup = 500 self.num_samples = 100 self.hidden_units = 10 # do inference self.rng, self.rng_predict = random.split(random.PRNGKey(0)) self.samples = run_inference(model, self.rng, X, Y, self.hidden_units, self.num_warmup, self.num_samples) def predict(self, x): all_preds =  for i in range(self.num_samples): y_pred = nn(x, self.samples['w1'][i], self.samples['w2'][i], self.samples['w3'][i]) all_preds.append(y_pred) all_preds = np.stack(all_preds) y_pred = np.mean(all_preds, axis=0) return y_pred # the non-linearity we use in our neural network def relu(x): return np.maximum(x, 0.) #return np.tanh(x) # a two-layer bayesian neural network with computational flow # given by D_X => D_H => D_H => D_Y where D_H is the number of # hidden units. (note we indicate tensor dimensions in the comments) def model(X, Y, D_H): D_X, D_Y = X.shape, 1 w1 = sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H)))) z1 = relu(np.matmul(X, w1)) w2 = sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H)))) z2 = relu(np.matmul(z1, w2)) w3 = sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y)))) y_pred = np.matmul(z2, w3) sigma = 0.01 * np.ones_like(y_pred) # Homoscedastic variance for now. # observe data sample("Y", dist.Normal(y_pred, sigma), obs=Y) def nn(x, w1, w2, w3): x = relu(np.matmul(x, w1)) x = relu(np.matmul(x, w2)) x = np.matmul(x, w3) return x # helper function for HMC inference def run_inference(model, rng, X, Y, D_H, num_warmup, num_samples): init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H) samples = mcmc(num_warmup, num_samples, init_params, sampler='hmc', potential_fn=potential_fn, constrain_fn=constrain_fn) return samples if __name__ == '__main__': x = np.linspace(-2, 2, 100).reshape(-1, 1) y = x**3 x_test = np.linspace(-3, 3, 100).reshape(-1, 1) hmc = HMCBNNNP(x, y) y_pred = hmc.predict(x_test) plt.scatter(x, y) plt.plot(x_test, y_pred) plt.show()