I have basically recreated an example from PyMC3 in numpyro.

It works just fine as long as the simulation time is small enough. If I let the simulation run longer, more and more divergences start showing up, the posterior mean and standard deviation of the parameter sigma grow, and eventually the results look completely wild.

Setting `t1 = 45`

in line 28 results in 8000/8000 divergences. Setting it to `t1 = 35`

results in about 5300/8000 divergences. Could anyone give me a clue as to what is going on and what I could do about it?

Thank you very much in advance. So far it has been a real pleasure working with numpyro!

Here is a more or less minimal working example:

```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from jax.config import config
config.update('jax_enable_x64', True)
import matplotlib.pyplot as pt
import arviz as az
from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
device = 'cpu'
num_chains = 4
numpyro.set_platform(device)
numpyro.set_host_device_count(num_chains)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as np
y0 = [0.99, 0.01]
params_true = jnp.asarray((2.0, 1.0))
t0 = 0.; t1 = 15.; dt = 0.25
t = jnp.arange(0, t1, dt)
def SIR(y, t, params):
    """Right-hand side of the SIR model (recovered compartment implicit).

    Args:
        y: state with ``y[0] = S`` (susceptible) and ``y[1] = I`` (infected).
        t: time; unused because the system is autonomous, but required by
            the ``odeint`` callback signature.
        params: array whose last axis holds ``(beta, gamma)``.

    Returns:
        ``jnp`` array ``[dS/dt, dI/dt]``.
    """
    # NOTE(review): the code as pasted had lost all indentation (a
    # SyntaxError); structure restored here, logic unchanged.
    S = y[0]
    I = y[1]
    # Index the trailing axis so a batch of parameter vectors also works.
    beta, gamma = (
        params[..., 0],
        params[..., 1],
    )
    dSdt = -beta * S * I
    dIdt = beta * S * I - gamma * I
    return jnp.stack([dSdt, dIdt])
def model(SIR, y0, t1, dt, y_obs=None):
    """NumPyro model: SIR dynamics observed through log-normal noise.

    Args:
        SIR: ODE right-hand side ``f(y, t, params)``.
        y0: nominal initial state ``[S0, I0]``; only ``I0`` informs the
            prior location (the sampled ``I0`` replaces it).
        t1: end time of the integration window.
        dt: time step; observation times start at ``t = dt``.
        y_obs: observed infected fractions, or ``None`` for prior /
            posterior predictive sampling.
    """
    # Observation times exclude t = 0 (the initial condition).
    t = jnp.arange(float(dt), float(t1), float(dt))
    # Prior for the initial infected fraction, bounded to [0, 1].
    # FIX: the original built dist.HalfNormal(y0[1], 0.2) — HalfNormal
    # takes a single scale argument — and then reassigned its .support,
    # which does not change how the distribution samples or transforms.
    # TruncatedNormal expresses the intended bounded prior directly.
    I0 = numpyro.sample(
        "I0", dist.TruncatedNormal(loc=y0[1], scale=0.2, low=0.0, high=1.0)
    )
    y0 = jnp.asarray((1.0 - I0, I0))
    # Zero-truncated normal priors for (beta, gamma).
    params = numpyro.sample(
        'params',
        dist.TruncatedNormal(
            low=0.0,
            loc=jnp.array([3., 0.5]),
            scale=jnp.array([1., 1.]),
        ),
    )
    # Heavy-tailed prior on the observation-noise scale.
    sigma = numpyro.sample('σ', dist.HalfCauchy(1))
    # Integrate the ODE; y[:, 1] is the infected fraction at each time.
    y = odeint(SIR, y0, t, params, rtol=1e-6, atol=1e-5, mxstep=1000)
    numpyro.sample('y', dist.LogNormal(jnp.log(y[:, 1]), sigma), obs=y_obs)
# Simulate the ground truth with the true parameters.  Because y0 here is a
# Python list, odeint treats it as a pytree and returns per-compartment
# arrays; np.array(...) stacks them to (2, len(t)) and .T yields
# (len(t), 2) — presumably (time, compartment); verify against the plots.
y_true = np.array(
odeint(SIR, y0, t, params_true, rtol=1e-6, atol=1e-5, mxstep=1000)
).T
# Noisy log-normal observations of the infected fraction, skipping t = 0.
rng = np.random.RandomState(seed=20010516)
y_obs = np.nan_to_num(
rng.lognormal(mean=np.log(y_true[1:, 1]), sigma=0.3), nan=1e-8
)
# Guard against zeros/negatives before the log-normal likelihood.
y_obs = y_obs.clip(0.0001)
# NUTS with a dense mass matrix; 4 chains, 1000 warmup + 4000 samples each.
mcmc = MCMC(
NUTS(model, dense_mass=True),
num_warmup=1000,
num_samples=4000,
num_chains=4,
)
# push the inference button
mcmc.run(PRNGKey(1), SIR=SIR, y0=y0, t1=t1, dt=dt, y_obs=y_obs)
mcmc.print_summary()
# Posterior predictive draws of the observable 'y' (y_obs omitted).
pred = Predictive(
model,mcmc.get_samples())(PRNGKey(2), SIR=SIR, y0=y0, t1=t1, dt=dt
)["y"]
# Predictive mean and central 95% interval across draws.
mu = jnp.mean(pred, 0)
pi = [jnp.percentile(pred, 2.5, 0), jnp.percentile(pred, 97.5, 0)]
# Observations, predictive mean/interval, and the true trajectory.
pt.scatter(t[1:], y_obs, color='C1', label='Observations')
pt.plot(t[1:], mu, color='C1', label='Prediction')
pt.fill_between(t[1:], pi[0], pi[1], color='C1', alpha=0.25)
pt.plot(t, y_true[:,1], linestyle='--', color='r', label='Truth')
pt.legend()
# Posterior diagnostics via ArviZ.
data = az.from_numpyro(mcmc)
az.plot_posterior(
data,
round_to=2,
hdi_prob=0.95,
)
pt.show()
```