Strange behaviour when using `odeint` and "longer" simulation times

I have basically recreated an example from PyMC3 (an SIR model fitted through an ODE solver) in NumPyro.
It works fine as long as the simulation time is short enough. If I let the simulation run longer, more and more divergences show up, the mean and standard deviation of the noise parameter sigma grow, and eventually the results look completely wild.
Setting t1 = 45 (in the line `t0 = 0.; t1 = 15.; dt = 0.25` of the script) results in 8000/8000 divergences, and t1 = 35 gives about 5300/8000. Could anyone give me a clue about what is going on and what I could do about it?
Thank you very much in advance. So far it has been a real pleasure working with numpyro!

Here is a more or less minimal working example:

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import jax
jax.config.update('jax_enable_x64', True)  # use 64-bit floats throughout

import matplotlib.pyplot as pt
import arviz as az

from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
device = 'cpu'
num_chains = 4
# both calls must run before JAX initializes its backend
numpyro.set_platform(device)
numpyro.set_host_device_count(num_chains)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import numpy as np


y0 = [0.99, 0.01]                      # initial (S, I)
params_true = jnp.asarray((2.0, 1.0))  # true (beta, gamma)

t0 = 0.; t1 = 15.; dt = 0.25
t = jnp.arange(t0, t1, dt)


def SIR(y, t, params):
    S = y[0]
    I = y[1]
    beta, gamma = (
        params[..., 0],
        params[..., 1],
    )
    dSdt = -beta * S * I
    dIdt = beta * S * I - gamma * I
    return jnp.stack([dSdt, dIdt])

def model(SIR, y0, t1, dt, y_obs=None):
    # start the grid at t = 0 so that y0 is the state at time 0;
    # the observations correspond to the grid points t[1:]
    t = jnp.arange(0.0, float(t1), float(dt))

    # dist.HalfNormal takes only a scale (its second positional argument is
    # validate_args), so HalfNormal(y0[1], 0.2) is a half-normal with scale
    # y0[1]; truncating it to [0, 1] gives a proper prior on that interval
    # instead of overriding `.support` by hand
    I0 = numpyro.sample(
        "I0", dist.TruncatedNormal(loc=0.0, scale=y0[1], low=0.0, high=1.0)
    )
    y0 = jnp.asarray((1.0 - I0, I0))

    params = numpyro.sample(
        'params',
        dist.TruncatedNormal(
            low=0.0,
            loc=jnp.array([3., 0.5]),
            scale=jnp.array([1., 1.]),
        ),
    )
    sigma = numpyro.sample('σ', dist.HalfCauchy(1))
    y = odeint(SIR, y0, t, params, rtol=1e-6, atol=1e-5, mxstep=1000)

    # drop the initial state so the predictions line up with y_obs
    numpyro.sample('y', dist.LogNormal(jnp.log(y[1:, 1]), sigma), obs=y_obs)

# y0 is a list, so odeint returns a list [S(t), I(t)];
# np.array(...).T turns it into an array of shape (len(t), 2)
y_true = np.array(
    odeint(SIR, y0, t, params_true, rtol=1e-6, atol=1e-5, mxstep=1000)
).T
rng = np.random.RandomState(seed=20010516)
# NaNs appear if the solver returns a non-positive I (np.log of <= 0)
y_obs = np.nan_to_num(
    rng.lognormal(mean=np.log(y_true[1:, 1]), sigma=0.3), nan=1e-8
)
y_obs = y_obs.clip(0.0001)

mcmc = MCMC(
    NUTS(model, dense_mass=True),
    num_warmup=1000,
    num_samples=4000,
    num_chains=num_chains,
)
# push the inference button
mcmc.run(PRNGKey(1), SIR=SIR, y0=y0, t1=t1, dt=dt, y_obs=y_obs)
mcmc.print_summary()

pred = Predictive(model, mcmc.get_samples())(
    PRNGKey(2), SIR=SIR, y0=y0, t1=t1, dt=dt
)["y"]
mu = jnp.mean(pred, 0)
pi = [jnp.percentile(pred, 2.5, 0), jnp.percentile(pred, 97.5, 0)]

pt.scatter(t[1:], y_obs, color='C1', label='Observations')
pt.plot(t[1:], mu, color='C1', label='Prediction')
pt.fill_between(t[1:], pi[0], pi[1], color='C1', alpha=0.25)
pt.plot(t, y_true[:, 1], linestyle='--', color='r', label='Truth')
pt.legend()

data = az.from_numpyro(mcmc)
az.plot_posterior(
    data,
    round_to=2,
    hdi_prob=0.95,
)

pt.show()
```

have you tried lower tolerances and/or 64-bit precision?

Hi martinjankowiak,
thanks for the quick reply!
I have already set the precision to 64-bit (the `jax_enable_x64` flag at the top of the example), but I'll try lower tolerances tomorrow and see what happens. If that really helps, I'd be very curious to understand why it only matters for longer simulation times.
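A plausible reason why tolerances matter more for long runs (a hedged reading of the example, not something confirmed in this thread): `atol=1e-5` bounds the solver's absolute error on I, but the LogNormal likelihood compares `log(I)`, i.e. relative error. After the epidemic peak I decays roughly exponentially, so in a long run it can drop far below 1e-5, where the solver no longer resolves it and may even return non-positive values that `jnp.log` turns into NaN or -inf. The sketch below assumes the true parameters (beta, gamma) = (2, 1) and 64-bit floats, solves once with much tighter tolerances, and prints I and I/atol near t = 15, 35 and 45.

```python
import jax
jax.config.update("jax_enable_x64", True)  # must run before any arrays are created

import jax.numpy as jnp
from jax.experimental.ode import odeint


def SIR(y, t, params):
    # same SIR right-hand side as in the example above
    S, I = y[0], y[1]
    beta, gamma = params[..., 0], params[..., 1]
    return jnp.stack([-beta * S * I, beta * S * I - gamma * I])


params_true = jnp.asarray((2.0, 1.0))
y0 = jnp.asarray((0.99, 0.01))
dt = 0.25
t = jnp.arange(0.0, 45.0, dt)

# Reference solution with much tighter tolerances than the example uses.
y_ref = odeint(SIR, y0, t, params_true, rtol=1e-10, atol=1e-12)
I_ref = y_ref[:, 1]

atol_used = 1e-5  # atol passed to odeint in the example
for t_end in (15.0, 35.0, 45.0):
    idx = int(t_end / dt) - 1  # last grid point before t_end
    I_val = float(I_ref[idx])
    print(f"t = {float(t[idx]):5.2f}   I = {I_val:.2e}   I / atol = {I_val / atol_used:.2e}")
```

Where I/atol is far below 1, the loose solve says little about log I, which is what tightening `atol` addresses. Separately, `y_obs.clip(0.0001)` in the example floors the observations at 1e-4, so for long runs the late observations may sit orders of magnitude above the model's I, a gap in log space that the likelihood can only absorb by inflating sigma.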

Hello again,
lowering the tolerances and setting mxstep=jnp.inf seems to improve the results a little, but the problem is not fundamentally solved: if I increase the simulation time a bit more, the results look as bad as before.

Do you have any other suggestions?

well afaik this is more or less expected behavior for ODEs: long-time integration is hard. we know celestial mechanics very well and can predict where the earth will be in 10k years very accurately; however, afaik we can't predict where it will be in 10 million years (despite the fact that this problem has special symplectic structure).

so afaik your only choices are:

  • use a better integrator (e.g. exploiting any special structure your problem may have; probably your problem doesn’t have such structure although i don’t know for sure)
  • decrease tolerances etc
  • use tighter priors so that odeint doesn’t face particularly wacky parameter settings
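A minimal sketch of the first bullet, assuming the structure worth exploiting is simply that I(t) must stay positive (an assumption; the thread does not settle whether the problem has more useful structure): integrate u = log I instead of I. Since dI/dt = (beta * S - gamma) * I, the log-state obeys du/dt = beta * S - gamma, so I = exp(u) can never become zero or negative, and an absolute tolerance on u acts like a relative tolerance on I. `SIR_log` is a hypothetical helper, not part of NumPyro or JAX.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.experimental.ode import odeint


def SIR_log(z, t, params):
    """Hypothetical SIR right-hand side with the infected compartment in log space.

    z = (S, u) with u = log(I); returns (dS/dt, du/dt).
    """
    S, u = z[0], z[1]
    beta, gamma = params[..., 0], params[..., 1]
    I = jnp.exp(u)
    dSdt = -beta * S * I
    dudt = beta * S - gamma  # d(log I)/dt = (dI/dt) / I
    return jnp.stack([dSdt, dudt])


params_true = jnp.asarray((2.0, 1.0))
t = jnp.arange(0.0, 45.0, 0.25)
z0 = jnp.asarray((0.99, jnp.log(0.01)))

# Same tolerances as in the example; positivity of I no longer depends on them.
z = odeint(SIR_log, z0, t, params_true, rtol=1e-6, atol=1e-5)
log_I = z[:, 1]  # could be passed directly as the loc of dist.LogNormal
I = jnp.exp(log_I)
```

In the model, the initial state would then become `jnp.asarray((1.0 - I0, jnp.log(I0)))` and the likelihood `dist.LogNormal(z[1:, 1], sigma)`, which avoids taking `jnp.log` of the solver output. This removes one numerical failure mode; it does not make long-time integration easy, so tighter priors and tolerances may still be needed.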

also you may find this helpful