Gradients with JAX using HMC

Hi!

I am trying to implement a BPINN (Bayesian physics-informed neural network) with HMC in NumPyro, but I am having some trouble with JAX's functional programming paradigm. Since I need the gradients of the network output to compute the mean of the physics term, I run into an error about a numpyro.sample site being declared twice, and I can't get my head around it.

This is the implementation so far:

import jax.numpy as jnp
from jax import grad
import numpyro
import numpyro.distributions as dist

def nonlin(x):
    return jnp.tanh(x)

def model_bnn(p, t, D_H, sigma_w=1):

    X = jnp.concatenate((p, t), axis=1)
    
    D_X, D_Y = X.shape[1], 1
    
    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w*jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))   # N D_H  <= first layer of activations

    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w*jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w*jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w*jnp.ones((D_Y, 1))))  # D_Y 1
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)  # N D_Y  <= output of the neural network

    return z3

grad_bnn = grad(model_bnn, argnums=1)

def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2

    u_mu = model_bnn(p, t, D_H, sigma_w)
    dudt = grad_bnn(p, t, D_H, sigma_w)
    dudtt = grad_bnn(p, dudt, D_H, sigma_w)
    
    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', p.shape[0]):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_hat) - p # Forcing physics-term, always=0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)
    
    return u_hat, f_hat

Any help is appreciated!

I believe you need to pull your sample statements outside of model_bnn and then pass them in as args. The problem is that you effectively call model_bnn three times and thus duplicate those sample statements. Let us know if that doesn't work!
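
Schematically, a minimal one-weight toy sketch of that restructuring (a made-up example, not the model from this thread):

import jax.numpy as jnp
from jax import grad
import numpyro
import numpyro.distributions as dist

def net(t, w):                        # pure function: no numpyro.sample inside
    return jnp.tanh(w * t)

dnet_dt = grad(net, argnums=0)        # derivative w.r.t. t

def model(t, y=None):
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))   # each site declared exactly once
    mu = net(t, w)                                   # reuse the same draw here...
    dmu = dnet_dt(t, w)                              # ...and for its derivative
    numpyro.sample("y", dist.Normal(mu + dmu, 1.0), obs=y)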

Thanks, that worked perfectly! Now, is there a cleaner way to get the first and second derivatives than my approach below? The MCMC sampling now takes hours, compared to a few minutes when not computing the first and second derivatives. I first tried grad(model)(args) and grad(grad(model))(args), but since the output of the BNN is an (n, 1) vector, I had to use jacfwd and the hessian instead.

Note that p and t are both (n, 1) vectors.

def nonlin(x):
    return jnp.tanh(x)

def model_bnn(p, t, w1, b1, w2, b2, w3, b3):
    X = jnp.concatenate((p, t), axis=1)
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)
    return z3.squeeze()

from jax import jacfwd, hessian

first_grad = jacfwd(model_bnn, argnums=1)
second_grad = hessian(model_bnn, argnums=1)

def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2
    
    D_X, D_Y = 2, 1
    
    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w*jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w*jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w*jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w*jnp.ones((D_Y, 1))))  # D_Y 1

    u_mu = model_bnn(p, t, w1, b1, w2, b2, w3, b3).reshape(-1,1)
    dudt = jnp.diagonal(first_grad(p, t, w1, b1, w2, b2, w3, b3).squeeze()).reshape(-1,1)
    dudtt = jnp.diagonal(jnp.diagonal(second_grad(p, t, w1, b1, w2, b2, w3, b3).squeeze())).reshape(-1,1)

    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', p.shape[0]):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_hat) - p # Forcing physics-term, always=0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)
    
    return u_mu, f_mu

Thanks!

I don't think so.

Note that NUTS will run model_bpinn and compute the gradient of the resulting log density dozens of times per iteration. This is inherently expensive unless all your dimensions are very small.

I think you can use vmap(grad(f))(x) to compute a batch of gradients of f over x. Similarly for the second derivatives: vmap(grad(grad(f)))(x).
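
For instance, on a toy scalar function that pattern would look like this (my own illustration, not the model above):

import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return jnp.sin(x) ** 2            # scalar in, scalar out

xs = jnp.linspace(0.0, 1.0, 5)
dfdx = vmap(grad(f))(xs)              # first derivative of f at each point
d2fdx2 = vmap(grad(grad(f)))(xs)      # second derivative of f at each point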

That was a beast of an optimization; I managed to get it down to 15 minutes with these functions:

def model_bnn(p, t, w1, b1, w2, b2, w3, b3):
    X = jnp.hstack([p,t])
    z1 = nonlin(jnp.matmul(X, w1) + jnp.transpose(b1))
    z2 = nonlin(jnp.matmul(z1, w2) + jnp.transpose(b2))
    z3 = jnp.matmul(z2, w3) + jnp.transpose(b3)
    return jnp.reshape(z3, ()) # scalar

from jax import grad, vmap, value_and_grad

mu_grad = vmap(value_and_grad(model_bnn, argnums=1), (0, 0, None, None, None, None, None, None), 0)
second_grad = vmap(grad(grad(model_bnn, argnums=1), argnums=1), (0, 0, None, None, None, None, None, None), 0)

def model_bpinn(p, t, Y, F, D_H, u_sigma=None, f_sigma=None, sigma_w=1):

    m = 0.15
    d = 0.15
    B = 0.2
    
    D_X, D_Y = 2, 1
    
    # sample first layer
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), sigma_w*jnp.ones((D_X, D_H))))  # D_X D_H
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), sigma_w*jnp.ones((D_H, D_H))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H, 1)), sigma_w*jnp.ones((D_H, 1))))  # D_H 1
    # sample final layer
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), sigma_w*jnp.ones((D_H, D_Y))))  # D_H D_Y
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y, 1)), sigma_w*jnp.ones((D_Y, 1))))  # D_Y 1

    u_mu, dudt = mu_grad(p.squeeze(), t.squeeze(), w1, b1, w2, b2, w3, b3)
    dudtt = second_grad(p.squeeze(), t.squeeze(), w1, b1, w2, b2, w3, b3)
    
    # prior on the observation noise
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    # observe data
    with numpyro.plate('observations', p.shape[0]):
        u_hat = numpyro.sample("Y", dist.Normal(u_mu.reshape(-1,1), u_sigma), obs=Y)
        f_mu = m * dudtt.reshape(-1,1) + d * dudt.reshape(-1,1) + B * jnp.sin(u_hat) - p # Forcing physics-term, always=0
        f_hat = numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)
    
    return u_mu, f_mu

The jnp.reshape(z3, ()) inside the BNN was the annoying part: JAX complained about the (1,)-shaped output, because grad requires a scalar-valued function.
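
For reference, a tiny illustration of that constraint (my own example, unrelated to the model):

import jax.numpy as jnp
from jax import grad

def f_vec(x):
    return jnp.array([x ** 2])                    # shape (1,): grad(f_vec)(1.0) raises a TypeError

def f_scalar(x):
    return jnp.reshape(jnp.array([x ** 2]), ())   # shape (): fine

grad(f_scalar)(1.0)                               # -> 2.0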

Is there a way to tell numpyro.sample("Y") to use the observation for some entries and to sample for the others inside the vectorized plate? I tried passing a vector with values and some np.nan entries (to act as None), but that didn't do the trick.

You can use mask at site "Y" as in this thread (this needs the numpyro master branch) or do Bayesian imputation as in this tutorial.

I'm trying to reproduce a minimal example to understand mask, but it always returns the observation:

from numpyro import handlers

def masked_model(x, y, mask):
    with numpyro.plate('data', len(x)):
        with handlers.mask(mask=mask):
            Y = numpyro.sample("Y", dist.Normal(x, 1.), obs=y)
    return Y

x = jnp.array([10., 10., 10., 10., 10.])
y = jnp.array([-1., -1., -1., -1., -1.])
mask = jnp.array([True, False, True, False, True])

handlers.seed(masked_model, rng_seed=0)(x, y, mask)
# returns DeviceArray([-1., -1., -1., -1., -1.], dtype=float32)

The mask handler is only used to mask out the log probability (i.e. to ignore the likelihood of the nan terms in your case). If you want to mask the values, you can use np.where or something like that directly on Y or y; there is no need for the mask handler. See the docs.
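
For example, one way to combine the two, so that the model sees observed values where they exist and sampled values elsewhere (the site names and nan handling below are my own sketch, not from this thread):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def partially_observed(x, y, mask):
    # y may hold nan placeholders at unobserved entries; replace them before scoring
    y_safe = jnp.where(mask, y, 0.0)
    with numpyro.plate("data", len(x)):
        with handlers.mask(mask=mask):
            numpyro.sample("Y_obs", dist.Normal(x, 1.0), obs=y_safe)  # likelihood only where mask is True
        y_miss = numpyro.sample("Y_miss", dist.Normal(x, 1.0))        # latent stand-ins for the rest
    return jnp.where(mask, y, y_miss)  # observed where available, sampled otherwise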

Okay, I think mask should work perfectly then. So, in the case of this nonsense model:

def masked_model(x, y, z, mask):
    A = numpyro.sample("A", dist.Normal(0., 1.).expand([len(x)]))
    B = A * x
    with numpyro.plate('data', len(y)):
        with handlers.mask(mask=mask):
            Y = numpyro.sample("Y", dist.Normal(B, 1.), obs=y)
        Z = numpyro.sample("Z", dist.Normal(B, 1.), obs=z)
    return Y, Z

x = jnp.array([10., 10.])
y = jnp.array([-1., -1.])
z = jnp.array([-1., -1.])
mask = jnp.array([True, False])

handlers.seed(masked_model, rng_seed=0)(x, y, z, mask)

Then the posterior will be p(theta | Y, Z) ∝ p(Y | theta) p(Z | theta) p(theta), with the likelihood p(Y | theta) contributing only for the entries where the mask is True, right?

Yes, you are right.
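
Concretely (in my notation), the log density the sampler sees is

log p(theta) + sum_i log p(Z_i | theta) + sum_i mask_i * log p(Y_i | theta)

so the masked-out Y entries simply drop out of the likelihood.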
