Hi!

I am trying to implement a BPINN (Bayesian physics-informed neural network) with HMC in NumPyro, but I am having some trouble with JAX's functional programming paradigm. Since I need the network's gradients to compute the mean of the physics term, I get an error about declaring a `numpyro.sample` site twice, and I can't get my head around it.

This is the implementation so far:

``````
def nonlin(x):
    return jnp.tanh(x)


def model_bnn(p, t, D_H, sigma_w=1):

    X = jnp.concatenate((p, t), axis=1)

    D_X, D_Y = X.shape[1], 1

    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w * jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))  # N D_H  <= first layer of activations

    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w * jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w * jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w * jnp.ones((D_Y, 1))))  # D_Y 1
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)  # N D_Y  <= output of the neural network

    return z3


def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2

    u_mu = model_bnn(p, t, D_H, sigma_w)
    dudt = grad_bnn(p, t, D_H, sigma_w)
    dudtt = grad_bnn(p, dudt, D_H, sigma_w)

    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', D_X):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_hat) - p  # forcing physics term, always = 0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)

    return u_hat, f_hat
``````

Any help is appreciated!

I believe you need to pull your `sample` statements outside of `model_bnn` and then pass them in as args. The problem is that you effectively call `model_bnn` three times and thus duplicate those sample statements. Let us know if that doesn't work!
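
Something along these lines (a minimal one-layer sketch, not your full model):

``````
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def bnn_forward(X, w, b):
    # pure function of its inputs: safe to call (and differentiate) as many times as you like
    return jnp.tanh(X @ w + b).squeeze()


def model(X, Y, D_X):
    # sample each parameter exactly once, at the top level
    w = numpyro.sample("w", dist.Normal(jnp.zeros((D_X, 1)), 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    mu = bnn_forward(X, w, b)  # reuse w and b freely, no duplicate sample sites
    numpyro.sample("Y", dist.Normal(mu, 0.1), obs=Y)
``````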

Thanks, that worked perfectly! Now, is there a cleaner way to get the first and second derivatives than my approach below? The MCMC sampling now takes hours, compared to a few minutes when I don't compute the first and second derivatives. I first tried `grad(model)(args)` and `grad(grad(model))(args)`, but since the output of the BNN is an (n, 1) vector, I had to use `jacfwd` and the Hessian instead.

Note that `p` and `t` are each (n, 1) vectors.
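
For context, `first_grad` and `second_grad` below are essentially the JAX Jacobian and Hessian of the network output with respect to `t`, roughly like this (a sketch, not the exact helpers):

``````
from jax import jacfwd, hessian

# Jacobian of the (n,) network output w.r.t. t, shape (n, n, 1); du_i/dt_i sits on its diagonal
first_grad = jacfwd(model_bnn, argnums=1)
# full Hessian w.r.t. t; d^2u_i/dt_i^2 sits on the doubly-diagonal entries
second_grad = hessian(model_bnn, argnums=1)
``````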

``````
def nonlin(x):
    return jnp.tanh(x)


def model_bnn(p, t, w1, b1, w2, b2, w3, b3):
    X = jnp.concatenate((p, t), axis=1)
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)
    return z3.squeeze()


def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2

    D_X, D_Y = 2, 1

    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w * jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w * jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w * jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w * jnp.ones((D_Y, 1))))  # D_Y 1

    u_mu = model_bnn(p, t, w1, b1, w2, b2, w3, b3).reshape(-1, 1)
    dudt = jnp.diagonal(first_grad(p, t, w1, b1, w2, b2, w3, b3).squeeze()).reshape(-1, 1)
    dudtt = jnp.diagonal(jnp.diagonal(second_grad(p, t, w1, b1, w2, b2, w3, b3).squeeze())).reshape(-1, 1)

    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', p.shape[0]):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_hat) - p  # forcing physics term, always = 0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)

    return u_mu, f_mu
``````

Thanks!

I don't think so.

Note that NUTS will run `model_bpinn` and compute the gradient of the resulting log density dozens of times per iteration. This is inherently expensive unless all your dimensions are very small.

I think you can use `vmap(grad(f))(x)` to compute a batch of grads of `f` given `x`. Similarly for the second derivatives: `vmap(grad(grad(f)))(x)`.
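
For example, with a generic scalar function (not your model):

``````
import jax.numpy as jnp
from jax import grad, vmap

f = lambda x: jnp.sin(x) ** 2   # scalar -> scalar

df = vmap(grad(f))              # first derivative at each point
d2f = vmap(grad(grad(f)))       # second derivative at each point

x = jnp.linspace(0.0, 1.0, 5)
print(df(x), d2f(x))
``````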


That was a beast of an optimization! I managed to reduce the runtime to 15 minutes with these functions:

``````
def model_bnn(p, t, w1, b1, w2, b2, w3, b3):
    X = jnp.hstack([p, t])
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)
    return jnp.reshape(z3, ())  # scalar


def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2

    D_X, D_Y = 2, 1

    # sample first layer
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w * jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w * jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w * jnp.ones((D_H, 1))))  # D_H 1
    # sample final layer
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w * jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w * jnp.ones((D_Y, 1))))  # D_Y 1

    u_mu, dudt = mu_grad(p.squeeze(), t.squeeze(), w1, b1, w2, b2, w3, b3)
    dudtt = second_grad(p.squeeze(), t.squeeze(), w1, b1, w2, b2, w3, b3)

    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', p.shape[0]):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu.reshape(-1, 1), u_sigma), obs=Y)
        f_mu = m * dudtt.reshape(-1, 1) + d * dudt.reshape(-1, 1) + B * jnp.sin(u_hat) - p  # forcing physics term, always = 0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)

    return u_mu, f_mu
``````

The `jnp.reshape(z3, ())` inside the BNN was quite annoying: JAX complained about the (1,) output because `grad` requires a scalar.
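
For completeness, `mu_grad` and `second_grad` are built with `vmap` roughly like this (a sketch; the exact `in_axes` are inferred from the signatures above):

``````
from jax import grad, value_and_grad, vmap

# value and du/dt at every data point in one call; model_bnn maps scalar (p, t) to a scalar
mu_grad = vmap(
    value_and_grad(model_bnn, argnums=1),
    in_axes=(0, 0, None, None, None, None, None, None),
)

# d^2u/dt^2 at every data point
second_grad = vmap(
    grad(grad(model_bnn, argnums=1), argnums=1),
    in_axes=(0, 0, None, None, None, None, None, None),
)
``````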

Is there a way to tell `numpyro.sample("Y")` to take the observation for some entries and to sample the rest inside the vectorized plate? I tried passing a vector with values plus some `np.nan` entries (to act as None), but that didn't do the trick.

You can use `mask` at site "Y" as in this thread (requires the numpyro `master` branch) or do Bayesian imputation as in this tutorial.

I am trying to reproduce a minimal example to understand `mask`, but it always returns the observation:

``````
def masked_model(x, y, mask):
    with numpyro.plate('data', len(x)):
        with numpyro.handlers.mask(mask=mask):
            Y = numpyro.sample("Y", dist.Normal(x, 1.), obs=y)
    return Y


x = jnp.array([10., 10., 10., 10., 10.])
y = jnp.array([-1., -1., -1., -1., -1.])
mask = jnp.array([True, False, True, False, True])

masked_model(x, y, mask)
# returns DeviceArray([-1., -1., -1., -1., -1.], dtype=float32)
``````

It is only used to mask out the log-probability (i.e. to ignore the probability of the nan terms in your case). If you want to mask the values, you can use `np.where` or something like that directly on `Y` or `y`; there is no need to use the mask handler for that. See the docs.
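
For the nan case, combining the two would look something like this (a sketch; `mu`, `sigma`, and `y_raw` are placeholders):

``````
obs_mask = ~jnp.isnan(y_raw)                # True where an observation is actually available
y_filled = jnp.where(obs_mask, y_raw, 0.0)  # replace nans so they never enter the log density

with numpyro.plate("data", y_raw.shape[0]):
    with numpyro.handlers.mask(mask=obs_mask):
        # masked entries contribute nothing to the log density; observed entries as usual
        numpyro.sample("Y", dist.Normal(mu, sigma), obs=y_filled)
``````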

Okay, I think `mask` should work perfectly then. So, in the case of this nonsense model:

``````
def masked_model(x, y, z, mask):
    A = numpyro.sample("A", dist.Normal(0., 1.).expand([len(x)]))
    B = A * x
    with numpyro.plate('data', len(y)):
        with numpyro.handlers.mask(mask=mask):
            Y = numpyro.sample("Y", dist.Normal(B, 1.), obs=y)
        Z = numpyro.sample("Z", dist.Normal(B, 1.), obs=z)
    return Y, Z


x = jnp.array([10., 10.])
y = jnp.array([-1., -1.])
z = jnp.array([-1., -1.])
mask = jnp.array([True, False])
``````

Then the posterior will be `p(theta | Y, Z) ∝ p(Y | theta) p(Z | theta) p(theta)`, with the likelihood `p(Y | theta)` appearing in the equation only for the cases where the mask is True, right?