Differentiate a neural network wrt inputs

I am interested in reproducing the paper: [2003.06097] B-PINNs: Bayesian Physics-Informed Neural Networks for Forward and Inverse PDE Problems with Noisy Data

Here, the likelihood term also contains the derivative of the neural network with respect to its input (see eq. (4) and eq. (5) in the paper above).

I am currently working with the NumPyro Bayesian neural network example. Is it possible to differentiate the neural network with respect to its input within the model?

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def nonlin(x):
    # same non-linearity as in NumPyro's BNN example
    return jnp.tanh(x)


def model(X, Y, D_H, D_Y=1):
    N, D_X = X.shape

    # priors over the weights and biases of a 2-hidden-layer MLP
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((D_H,)), jnp.ones((D_H,))))

    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((D_H,)), jnp.ones((D_H,))))

    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((D_Y,)), jnp.ones((D_Y,))))

    # forward pass
    z1 = nonlin(jnp.matmul(X, w1) + b1)
    z2 = nonlin(jnp.matmul(z1, w2) + b2)
    z3 = jnp.matmul(z2, w3) + b3

    # compute derivative of z3 wrt X
    # dz3_dX = ?
```
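For intuition, assuming `nonlin` is `tanh` (as in NumPyro's BNN example), the quantity asked for can be written out by hand for a single input row $x \in \mathbb{R}^{1 \times D_X}$ with the chain rule and $\tanh'(a) = 1 - \tanh^2(a)$, where $z_1 = \tanh(x W_1 + b_1)$, $z_2 = \tanh(z_1 W_2 + b_2)$ and $z_3 = z_2 W_3 + b_3$:

$$
\frac{\partial z_3}{\partial x} = W_3^\top \,\mathrm{diag}\!\left(1 - z_2 \odot z_2\right) W_2^\top \,\mathrm{diag}\!\left(1 - z_1 \odot z_1\right) W_1^\top \in \mathbb{R}^{D_Y \times D_X}
$$

This is one $D_Y \times D_X$ Jacobian per data point. Automatic differentiation produces the same thing without the manual derivation, which matters as soon as the architecture or the activation changes.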

If this is not possible with numpyro, is it possible with pyro + torch?

It should be possible, and probably easier in JAX/NumPyro. Can you be more specific?
