Slow likelihood/gradient evaluation in NUTS sampling from ODE model with large data

Hi NumPyro community,

I’m working on sampling parameter distributions of an ODE model.
With a small dataset (n=5), the model samples very fast on CPU (< 1 min), but once I increase the amount of data (n=300), the gradient evaluations take a long time. At some point this becomes infeasible.

So my question is: Have I reached the limit here? Or can I specify the probability model more efficiently, or parallelize the likelihood/gradient evaluation somehow?

Below you'll find a minimal example with a fast run and a slow run.

I’d appreciate any leads, ideas, or clear answers :smiley:

Best,
Florian

Dependencies:

python 3.11.5
jax 0.4.21
numpyro 0.13.2
diffrax 0.4.1
arviz 0.15.1

Minimal Example:


from functools import partial

import arviz as az

import jax
import jax.numpy as jnp

import diffrax
from diffrax import (
    diffeqsolve, 
    Dopri5, 
    ODETerm, 
    SaveAt, 
    PIDController, 
    RecursiveCheckpointAdjoint
)

import numpyro
from numpyro.distributions import (
    LogNormal,
    Binomial
)

from numpyro.infer import (
    init_to_median,
    MCMC,
    NUTS,
)


def tktd_guts_minimal(t, X, k_i, k_e, k_a, k_r):
    """ODE Model        
    """
    Ce, Ci, D = X

    dCe_dt = 0
    dCi_dt = k_i * Ce - k_e * Ci
    dD_dt = Ci * k_a - D * k_r

    return jnp.array([dCe_dt, dCi_dt, dD_dt])


def survival_jax(t, damage, z, kk, h_b):
    """
    survival probability derived from hazard by using the trapezoidal rule
    """
    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
    # cumulative hazard: integrate the hazard up to each time point t[i]
    # (trapezoidal rule; the Python loop is unrolled when jit-compiled)
    H = jnp.array([jax.scipy.integrate.trapezoid(
        hazard[:i+1], t[:i+1], axis=0) for i in range(len(t))]
    )
    S = jnp.exp(-H)

    return S
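

# An equivalent vectorized formulation of the cumulative trapezoidal integral above,
# sketched with jnp.cumsum (illustrative helper, not used elsewhere in this example):
def survival_jax_cumulative(t, damage, z, kk, h_b):
    """survival probability from the cumulative hazard, accumulating per-interval areas"""
    hazard = kk * jnp.maximum(damage - z, 0) + h_b
    # trapezoid area of each interval [t[i-1], t[i]]
    increments = 0.5 * (hazard[1:] + hazard[:-1]) * jnp.diff(t)
    # cumulative hazard H(t_i), with H(t_0) = 0
    H = jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    return jnp.exp(-H)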


@jax.jit
def odesolve_batch(y0, time, theta):
    """small wrapper around odesolve to vectorize computations of multiple samples
    efficiently"""
    solve_batched = jax.vmap(
        partial(odesolve, theta=theta, time=time), 
        in_axes=(0,)
    )

    res = solve_batched(y0)
    return res


@jax.jit
def odesolve(y0, time, theta):
    """the deterministic solver and the survival function"""
    f = lambda t, y, theta: tktd_guts_minimal(t, y, *theta)
    
    term = ODETerm(f)
    solver = Dopri5()
    saveat = SaveAt(ts=time)
    stepsize_controller = PIDController(rtol=1e-6, atol=1e-7)

    sol = diffeqsolve(
        terms=term, 
        solver=solver, 
        t0=time.min(), 
        t1=time.max(), 
        dt0=0.1, 
        y0=y0, 
        args=theta[:4], 
        saveat=saveat, 
        stepsize_controller=stepsize_controller,
        adjoint=RecursiveCheckpointAdjoint(),
        max_steps=10**6,
        throw=False,   
    )
    
    damage = sol.ys[:, 2]
    surv = survival_jax(t=time, damage=damage, z=theta[4], kk=theta[5], h_b=theta[6])
    
    return jnp.column_stack([sol.ys, surv])


@jax.jit
def add_noise_jax(Y, n, key, sigma_ce, sigma_ci, sigma_d):
    """convenience function to add noise to a dataset"""
    lognorm = jax.random.lognormal
    binom = jax.random.binomial
    key, k0, k1, k2, k3 = jax.random.split(key, 5)

    # generate noise vectors to be multiplied 
    y0_noise = lognorm(k0, sigma=sigma_ce, shape=Y[:, :, 0].shape)
    y1_noise = lognorm(k1, sigma=sigma_ci, shape=Y[:, :, 1].shape)
    y2_noise = lognorm(k2, sigma=sigma_d, shape=Y[:, :, 2].shape)

    # do elementwise matrix multiplication along the j-axis of the matrix
    y0 = jnp.einsum("jm,jm -> jm", y0_noise, Y[:, :, 0])
    y1 = jnp.einsum("jm,jm -> jm", y1_noise, Y[:, :, 1])
    y2 = jnp.einsum("jm,jm -> jm", y2_noise, Y[:, :, 2])
    
    y3 = binom(k3, p=Y[:, :, 3], n=n)
    
    # stack along the last axis
    return jnp.stack([y0, y1, y2, y3], axis=2)


def prob_model(solver, time, y0, obs=None, masks=None):
    """The probabilistic model to sample from the posterior parameter 
    distributions of the ODE model"""
    # parameters of the deterministic model
    k_i = numpyro.sample("k_i", LogNormal(jnp.log(0.1), scale=1))
    k_e = numpyro.sample("k_e", LogNormal(jnp.log(0.05), scale=1))
    k_a = numpyro.sample("k_a", LogNormal(jnp.log(0.2), scale=1))
    k_r = numpyro.sample("k_r", LogNormal(jnp.log(0.01), scale=1))
    z = numpyro.sample("z", LogNormal(jnp.log(0.01), scale=1))
    kk = numpyro.sample("kk", LogNormal(jnp.log(0.01), scale=1))
    h_b = numpyro.sample("h_b", LogNormal(jnp.log(0.01), scale=1))
    eps = 1e-8

    # parameters of the error model
    # currently the scale parameter of the error distribution is fixed (scale=0.1)
    # this parameter can also be drawn from a random variable. For instance,
    # from a half-normal or a uniform(0, 100) or something positive constrained.
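    # e.g. (hypothetical, not used in this minimal example):
    # sigma_ce = numpyro.sample("sigma_ce", numpyro.distributions.HalfNormal(1.0))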

    # parameters must be in the correct order, or better provided as a dictionary
    # and extracted accordingly
    theta = (k_i, k_e, k_a, k_r, z, kk, h_b)

    # compute the deterministic model and store the variables
    res = solver(theta=theta, time=time, y0=y0)

    Ce = numpyro.deterministic("Ce", res[:, :, 0])
    Ci = numpyro.deterministic("Ci", res[:, :, 1])
    D = numpyro.deterministic("D", res[:, :, 2])
    L = numpyro.deterministic("L", res[:, :, 3])

    # calculate the likelihood only if observations are specified
    if obs is not None:
        numpyro.sample("Ce_obs", LogNormal(jnp.log(Ce + eps), scale=0.1), obs=obs[:, :, 0] + eps)
        numpyro.sample("Ci_obs", LogNormal(jnp.log(Ci + eps), scale=0.1), obs=obs[:, :, 1] + eps)
        numpyro.sample("D_obs", LogNormal(jnp.log(D + eps), scale=0.1), obs=obs[:, :, 2] + eps)
        numpyro.sample("L_obs", Binomial(total_count=10, probs=L), obs=obs[:, :, 3])


def run_inference(key, solver, time, y0, obs=None, masks=None):
    """run MCMC on NUTS kernel and return arviz.InferenceData"""
    # initialize the probability model with constants
    model = partial(prob_model, solver=solver, time=time, y0=y0, obs=obs, masks=masks)

    kernel = NUTS(
        model=model, 
        dense_mass=True, 
        step_size=0.01,
        adapt_mass_matrix=True,
        adapt_step_size=True,
        max_tree_depth=10,
        target_accept_prob=0.8,
        init_strategy=init_to_median
    )

    mcmc = MCMC(
        sampler=kernel,
        num_warmup=2000,
        num_samples=2000,
        num_chains=1,
        progress_bar=True,
    )

    # run inference
    mcmc.run(key)
    mcmc.print_summary()

    # create arviz InferenceData
    data_vars = ["Ce", "Ci", "D", "L"]
    loglik = ["Ce_obs", "Ci_obs", "D_obs", "L_obs"]
    dims = ["id", "time"]

    idata = az.from_numpyro(
        mcmc, 
        dims={k: dims for k in data_vars + loglik},
        coords={"time": time, "id": list(range(len(y0)))},
    )

    idata.add_groups({"posterior_predictive": idata.posterior[data_vars]})
    idata.posterior = idata.posterior.drop(data_vars)

    return idata, mcmc, kernel

    
if __name__ == "__main__":

    print(
        f"jax {jax.__version__}",
        f"numpyro {numpyro.__version__}",
        f"diffrax {diffrax.__version__}",
        f"arviz {az.__version__}"
    )

    time = jnp.linspace(0, 120, 100)
    theta_0 = jnp.array([0.1, 0.05, 0.2, 0.01, 50, 0.0001, 1e-8])

    y0_batch = jnp.array(
        [
            [10, 0, 0],
            [20, 0, 0],
            [30, 0, 0],
            [40, 0, 0],
            [50, 0, 0],
        ]
    )

    # just to demonstrate the solver for a single measurement series
    # without a batch dimension this is not really useful.
    result_single = odesolve(y0=jnp.array([10, 0, 0]), time=time, theta=theta_0)


    # with only a few data points the model runs very fast.
    result_batch = odesolve_batch(y0=y0_batch, time=time, theta=theta_0)

    # generate random keys
    key = jax.random.PRNGKey(1)
    key, *subkeys = jax.random.split(key, 3)

    obs_noisy_batch = add_noise_jax(
        Y=result_batch, n=10, key=subkeys[0], 
        sigma_ce=0.1, sigma_ci=0.1, sigma_d=0.1
    )

    idata, mcmc, kernel = run_inference(
        key=subkeys[1], 
        solver=odesolve_batch, 
        time=time, 
        y0=y0_batch, 
        obs=obs_noisy_batch
    )

    az.plot_trace(idata)
    az.plot_pair(idata)


    # with a batch of 300 observations the model becomes very slow once NUTS
    # reaches a number of steps >~ 100
    # y0 here are just dummy data. In reality the y0s are different
    y0_batch = jnp.repeat(jnp.array([[10, 0, 0]]), repeats=300, axis=0)
    result_batch = odesolve_batch(y0=y0_batch, time=time, theta=theta_0)

    obs_noisy_batch = add_noise_jax(
        Y=result_batch, n=10, key=subkeys[0], 
        sigma_ce=0.1, sigma_ci=0.1, sigma_d=0.1
    )

    idata, mcmc, kernel = run_inference(
        key=subkeys[1], 
        solver=odesolve_batch, 
        time=time, 
        y0=y0_batch, 
        obs=obs_noisy_batch
    )

    az.plot_trace(idata)
    az.plot_pair(idata)

I didn’t look at your code in detail, but it’s generally expected that MCMC gets quite expensive (eventually prohibitively expensive) as:

  • the number of data points increases (or, more generally, as the cost of computing the likelihood and/or prior increases)
  • the dimensionality of the latent space increases.

In your case you might try:

  • using a GPU
  • using less stringent atol/rtol (this and the GPU option are sketched below)
  • seeing if any of the tips in this tutorial are applicable
  • trying HMCECS as in this example
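
For the first two points, a minimal sketch (the tolerance values are illustrative, and the GPU line assumes a CUDA-enabled jaxlib install):

numpyro.set_platform("gpu")  # select the GPU backend before running MCMC

# in odesolve: relax the solver tolerances compared to rtol=1e-6 / atol=1e-7
stepsize_controller = PIDController(rtol=1e-4, atol=1e-5)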

Thanks for your response. Knowing the limitations helps a lot.

I already tested most of the approaches you mentioned without much improvement, so I think I’ve reached the limit of what I can do with NUTS/MCMC here.