I have basically recreated an example from PyMC3 in numpyro.

It works just fine as long as the simulation time is small enough. If I let the simulation run longer, first more and more divergences start showing up and the parameter sigma increases in mean and std and finally the results look completely wild.

Setting `t1 = 45`

in line 28 results in 8000/8000 divergences. Setting it to `t1 = 35`

has about 5300/8000 divergences. Could anyone give me a clue to what is going on and what I could do about it?

Thank you very much in advance. So far it has been a real pleasure working with numpyro!

Here is a more or less minimal working example

```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from jax.config import config
config.update('jax_enable_x64', True)
import matplotlib.pyplot as pt
import arviz as az
from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
device = 'cpu'
num_chains = 4
numpyro.set_platform(device)
numpyro.set_host_device_count(num_chains)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as np
y0 = [0.99, 0.01]
params_true = jnp.asarray((2.0, 1.0))
t0 = 0.; t1 = 15.; dt = 0.25
t = jnp.arange(0, t1, dt)
def SIR(y, t, params):
S = y[0]
I = y[1]
beta, gamma = (
params[..., 0],
params[..., 1],
)
dSdt = -beta * S * I
dIdt = beta * S * I - gamma * I
return jnp.stack([dSdt, dIdt])
def model(SIR, y0, t1, dt, y_obs=None):
t = jnp.arange(float(dt), float(t1), float(dt))
bounded_dist = dist.HalfNormal(y0[1], 0.2)
bounded_dist.support = dist.constraints.interval(0., 1.)
I0 = numpyro.sample("I0", bounded_dist)
y0 = jnp.asarray((1.0 - I0, I0))
params = numpyro.sample(
'params',
dist.TruncatedNormal(
low=0.0,
loc=jnp.array([3., 0.5]),
scale=jnp.array([1., 1.]),
),
)
sigma = numpyro.sample('σ', dist.HalfCauchy(1))
y = odeint(SIR, y0, t, params, rtol=1e-6, atol=1e-5, mxstep=1000)
numpyro.sample('y', dist.LogNormal(jnp.log(y[:,1]), sigma), obs=y_obs)
y_true = np.array(
odeint(SIR, y0, t, params_true, rtol=1e-6, atol=1e-5, mxstep=1000)
).T
rng = np.random.RandomState(seed=20010516)
y_obs = np.nan_to_num(
rng.lognormal(mean=np.log(y_true[1:, 1]), sigma=0.3), nan=1e-8
)
y_obs = y_obs.clip(0.0001)
mcmc = MCMC(
NUTS(model, dense_mass=True),
num_warmup=1000,
num_samples=4000,
num_chains=4,
)
# push the inference button
mcmc.run(PRNGKey(1), SIR=SIR, y0=y0, t1=t1, dt=dt, y_obs=y_obs)
mcmc.print_summary()
pred = Predictive(
model,mcmc.get_samples())(PRNGKey(2), SIR=SIR, y0=y0, t1=t1, dt=dt
)["y"]
mu = jnp.mean(pred, 0)
pi = [jnp.percentile(pred, 2.5, 0), jnp.percentile(pred, 97.5, 0)]
pt.scatter(t[1:], y_obs, color='C1', label='Observations')
pt.plot(t[1:], mu, color='C1', label='Prediction')
pt.fill_between(t[1:], pi[0], pi[1], color='C1', alpha=0.25)
pt.plot(t, y_true[:,1], linestyle='--', color='r', label='Truth')
pt.legend()
data = az.from_numpyro(mcmc)
az.plot_posterior(
data,
round_to=2,
hdi_prob=0.95,
)
pt.show()
```