Hi!
I am trying to implement the HMC BPINN with Numpyro, but I am having some troubles with JAX’s functional programming paradigm. Since I need the gradients to compute the mean, I get an error of declaring two times a numpyro.sample variable but I can’t get my head around it.
This is the implementation so far:
def nonlin(x):
return jnp.tanh(x)
def model_bnn(p, t, D_H, sigma_w=1):
X = jnp.concatenate((p, t), axis=1)
D_X, D_Y = X.shape[1], 1
# sample first layer (we put unit normal priors on all weights)
w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w*jnp.ones((D_X, D_H)))) # D_X D_H
b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1)))) # D_H 1
z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1)) # N D_H <= first layer of activations
# sample second layer
w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w*jnp.ones((D_H, D_H)))) # D_H D_H
b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1)))) # D_H 1
z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2)) # N D_H <= second layer of activations
# sample final layer of weights and neural network output
w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w*jnp.ones((D_H, D_Y)))) # D_H D_Y
b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w*jnp.ones((D_Y, 1)))) # D_H 1
z3 = jnp.matmul(z2, w3) + jnp.transpose(b3) # N D_Y <= output of the neural network
return z3
grad_bnn = grad(model_bnn, argnums=1)
def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):
m = 0.15
d = 0.15
B = 0.2
u_mu = model_bnn(p, t, D_H, sigma_w)
dudt = grad_bnn(p, t, D_H, sigma_w)
dudtt = grad_bnn(p, dudt, D_H, sigma_w)
# prior on the observation noise
if u_sigma is None:
prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
u_sigma = 1.0 / jnp.sqrt(prec_u)
if f_sigma is None:
prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
f_sigma = 1.0 / jnp.sqrt(prec_f)
# observe data
with numpyro.plate('observations', D_X):
u_hat = numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
f_mu = m * dudtt + d * dudt + B * jnp.sin(u_hat) - p # Forcing physics-term, always=0
f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)
return u_hat, f_hat
Any help is appreciated!