OK, thanks @fehiepsi — I've made some more progress. First, the apparent infinite-loop issue with NUTS seems to have been resolved by setting `progress_bar=True`. Strangely…
Next, here is a minimal example that reproduces the forward-mode differentiation error; the last line should trigger it.
import jax
import jax.numpy as jnp
from jax import ops, random
import numpy as np
from jax.experimental.ode import odeint
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoLaplaceApproximation
from numpyro.infer import SVI, Trace_ELBO
def dy_dt(z, t, A):
    """Right-hand side of the linear ODE dz/dt = A @ z.

    ``t`` is not used (the system is autonomous). It is kept because
    ``odeint`` calls the function as ``func(y, t, *args)``.
    """
    coupling = jnp.asarray(A)
    return coupling @ z
def model(data, ts, ind_mat):
    """NumPyro model for the linear ODE dz/dt = A z with unknown couplings.

    Args:
        data: observed trajectory. Row 0 is used as the ODE initial state and
            rows 1: are treated as noisy observations (assumed shape
            (len(ts), n_states) — TODO confirm for other callers).
        ts: time points at which the ODE solution is evaluated.
        ind_mat: boolean mask selecting which entries of ``A`` are filled
            with ``theta``; every other entry of ``A`` is zero. It must
            contain exactly two ``True`` entries to match ``theta``.
    """
    # Standard-normal prior on the two free couplings.  The previous
    # ``scale=jnp.array([1.0]) * 2`` evaluated to ``[2.0]`` (a scale of 2),
    # a misplaced bracket relative to ``loc=jnp.array([0.] * 2)``.
    theta = numpyro.sample(
        "theta", dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
    )
    # Scatter theta into the masked positions of an otherwise-zero matrix.
    # ``x.at[idx].set(y)`` replaces the deprecated ``jax.ops.index_update``.
    A = jnp.zeros(ind_mat.shape).at[ind_mat].set(theta)
    # Integrate from the observed initial state. Use the ``data`` argument,
    # not a module-level global, so the model conditions on what it is given.
    y_pred = odeint(dy_dt, data[0], ts, A, rtol=1e-6, atol=1e-5, mxstep=1000)
    # One observation-noise scale per state dimension.
    sigma = numpyro.sample(
        "sigma", dist.LogNormal(-1, 1).expand([data.shape[-1]])
    )
    numpyro.sample("y_pred", dist.Normal(y_pred[1:], sigma), obs=data[1:])
def get_MAP(data, ts, ind_mat):
    """Fit ``model`` with an AutoLaplaceApproximation guide using BFGS.

    A single ``svi.update`` call is made. With NumPyro's ``Minimize``
    optimizer, that one step presumably runs the whole BFGS minimization.

    Returns:
        A tuple ``(params, guide, svi, optimal_state, loss)``.
    """
    model_args = (data, ts, ind_mat)
    laplace_guide = AutoLaplaceApproximation(model)
    bfgs = numpyro.optim.Minimize(method="BFGS")
    engine = SVI(model, laplace_guide, bfgs, loss=Trace_ELBO())

    start = engine.init(random.PRNGKey(0), *model_args)
    final_state, final_loss = engine.update(start, *model_args)

    fitted = engine.get_params(final_state)
    return fitted, laplace_guide, engine, final_state, final_loss
# Synthetic two-state linear system dz/dt = K z, observed at 81 times.
ts = jnp.linspace(0, 240, 81)  # 0, 3, ..., 240; no list/NumPy round-trip
y_init = jnp.array([-0.06, -0.01])
K = jnp.array([[0.0, 0.01],
               [0.01, 0.0]])
ind_mat = K != 0  # entries of A that the model treats as unknown
fake_data = odeint(dy_dt, y_init, ts, K)

params, guide, svi, optimal_state, loss = get_MAP(fake_data, ts, ind_mat)
print(params)
# NOTE(review): quantiles need the Laplace approximation's Hessian. ``odeint``
# is differentiable only in reverse mode (it is a ``custom_vjp``). If the
# installed NumPyro builds that Hessian with forward-over-reverse autodiff
# (e.g. ``jax.hessian``), this line raises the forward-mode error mentioned
# above. Confirm this against the installed NumPyro version.
print(guide.quantiles(params, 0.91))