Hi,
First of all, thanks for NumPyro; it seems very promising! I hope this question isn’t too stupid.
I’m trying to sample from a BNN with HMC to solve a simple regression task. It works fine with a tanh activation, but when I swap in ReLU the model just returns a linear fit. I’ve had this issue for some time now and can’t find anything wrong with the code. I’ve also tried other formulations of the activation function, e.g. x * (x > 0), but it doesn’t seem to make a difference. Any ideas?
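For reference, here are the two ReLU formulations I’ve tried as drop-in replacements for the tanh (the helper names below are just for illustration); both give the same linear-looking fit:

import jax.numpy as np

def relu_max(x):
    return np.maximum(x, 0.)  # formulation used in the example below

def relu_mask(x):
    return x * (x > 0.)  # alternative formulation; makes no difference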
Here’s a somewhat minimal example:
import numpy as onp
from matplotlib import pyplot as plt
from jax import vmap
from jax.config import config as jax_config
import jax.numpy as np
import jax.random as random
import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc
class HMCBNNNP:
    def __init__(self, X, Y):
        jax_config.update('jax_platform_name', 'cpu')
        self.num_warmup = 500
        self.num_samples = 100
        self.hidden_units = 10
        # do inference
        self.rng, self.rng_predict = random.split(random.PRNGKey(0))
        self.samples = run_inference(model, self.rng, X, Y, self.hidden_units, self.num_warmup, self.num_samples)

    def predict(self, x):
        all_preds = []
        for i in range(self.num_samples):
            y_pred = nn(x, self.samples['w1'][i], self.samples['w2'][i], self.samples['w3'][i])
            all_preds.append(y_pred)
        all_preds = np.stack(all_preds)
        y_pred = np.mean(all_preds, axis=0)
        return y_pred
# the non-linearity we use in our neural network
def relu(x):
    return np.maximum(x, 0.)
    # return np.tanh(x)
# a two-layer Bayesian neural network with computational flow
# given by D_X => D_H => D_H => D_Y, where D_H is the number of
# hidden units
def model(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1
    w1 = sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))
    z1 = relu(np.matmul(X, w1))
    w2 = sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H))))
    z2 = relu(np.matmul(z1, w2))
    w3 = sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y))))
    y_pred = np.matmul(z2, w3)
    sigma = 0.01 * np.ones_like(y_pred)  # homoscedastic observation noise (std. dev.) for now
    # observe data
    sample("Y", dist.Normal(y_pred, sigma), obs=Y)
def nn(x, w1, w2, w3):
    x = relu(np.matmul(x, w1))
    x = relu(np.matmul(x, w2))
    x = np.matmul(x, w3)
    return x
# helper function for HMC inference
def run_inference(model, rng, X, Y, D_H, num_warmup, num_samples):
    init_params, potential_fn, constrain_fn = initialize_model(rng, model, X, Y, D_H)
    samples = mcmc(num_warmup, num_samples, init_params, sampler='hmc',
                   potential_fn=potential_fn, constrain_fn=constrain_fn)
    return samples
if __name__ == '__main__':
    x = np.linspace(-2, 2, 100).reshape(-1, 1)
    y = x**3
    x_test = np.linspace(-3, 3, 100).reshape(-1, 1)
    hmc = HMCBNNNP(x, y)
    y_pred = hmc.predict(x_test)
    plt.scatter(x, y)
    plt.plot(x_test, y_pred)
    plt.show()