Hi NumPyro community,
I’m working on sampling parameter distributions of an ODE model.
With few data (n=5), the model samples very fast on CPU (< 1 min), but once I increase the amount of data (n=300), the gradient evaluations take a long time and sampling eventually becomes infeasible.
So my question is: Have I reached the limit here? Or can I specify the probability model more efficiently, or parallelize the likelihood/gradient evaluation somehow?
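For reference, the cost of a single gradient evaluation can be isolated from NUTS and timed directly; this is only a rough sketch (the names prob_model, odesolve_batch, time, y0_batch, and obs_noisy_batch refer to the minimal example below):

from numpyro.infer.util import initialize_model

model_info = initialize_model(
    jax.random.PRNGKey(0),
    prob_model,
    model_kwargs=dict(
        solver=odesolve_batch, time=time, y0=y0_batch, obs=obs_noisy_batch
    ),
)
grad_fn = jax.jit(jax.grad(model_info.potential_fn))
# compile once, then time repeated calls to grad_fn on the initial parameters
_ = jax.block_until_ready(grad_fn(model_info.param_info.z))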
Below you find a minimal model with a fast run and a slow run.
I’d appreciate any leads, ideas, or clear answers.
Best,
Florian
Dependencies:
python 3.11.5
jax 0.4.21
numpyro 0.13.2
diffrax 0.4.1
arviz 0.15.1
Minimal Example:
from functools import partial
import arviz as az
import jax
import jax.numpy as jnp
import diffrax
from diffrax import (
    diffeqsolve,
    Dopri5,
    ODETerm,
    SaveAt,
    PIDController,
    RecursiveCheckpointAdjoint,
)
import numpyro
from numpyro.distributions import (
    LogNormal,
    Binomial,
)
from numpyro.infer import (
    init_to_median,
    MCMC,
    NUTS,
)
def tktd_guts_minimal(t, X, k_i, k_e, k_a, k_r):
    """ODE model: constant external concentration Ce, internal concentration Ci
    with uptake (k_i) and elimination (k_e), and damage D with accrual (k_a)
    and repair (k_r)."""
    Ce, Ci, D = X
    dCe_dt = 0
    dCi_dt = k_i * Ce - k_e * Ci
    dD_dt = Ci * k_a - D * k_r
    return jnp.array([dCe_dt, dCi_dt, dD_dt])
def survival_jax(t, damage, z, kk, h_b):
    """Survival probability derived from the cumulative hazard via the
    trapezoidal rule."""
    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
    H = jnp.array([
        jax.scipy.integrate.trapezoid(hazard[:i + 1], t[:i + 1], axis=0)
        for i in range(len(t))
    ])
    S = jnp.exp(-H)
    return S
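# Side note (a sketch only, not used in the runs below): the list comprehension
# above re-integrates from t[0] for every output point, so the work grows
# quadratically with len(t) and unrolls into a large graph under jit/grad.
# A cumulative trapezoidal rule computes the same H in one vectorized pass:
def survival_jax_cumsum(t, damage, z, kk, h_b):
    hazard = kk * jnp.where(damage - z < 0, 0, damage - z) + h_b
    increments = 0.5 * (hazard[1:] + hazard[:-1]) * jnp.diff(t)
    H = jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    return jnp.exp(-H)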
@jax.jit
def odesolve_batch(y0, time, theta):
    """Small wrapper around odesolve to vectorize computation over multiple
    samples efficiently."""
    solve = jax.vmap(
        partial(odesolve, theta=theta, time=time),
        in_axes=(0,)
    )
    res = solve(y0)
    return res
@jax.jit
def odesolve(y0, time, theta):
    """The deterministic ODE solver plus the survival function."""
    f = lambda t, y, theta: tktd_guts_minimal(t, y, *theta)
    term = ODETerm(f)
    solver = Dopri5()
    saveat = SaveAt(ts=time)
    stepsize_controller = PIDController(rtol=1e-6, atol=1e-7)
    sol = diffeqsolve(
        terms=term,
        solver=solver,
        t0=time.min(),
        t1=time.max(),
        dt0=0.1,
        y0=y0,
        args=theta[:4],
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        adjoint=RecursiveCheckpointAdjoint(),
        max_steps=10**6,
        throw=False,
    )
    damage = sol.ys[:, 2]
    surv = survival_jax(t=time, damage=damage, z=theta[4], kk=theta[5], h_b=theta[6])
    return jnp.column_stack([sol.ys, surv])
@jax.jit
def add_noise_jax(Y, n, key, sigma_ce, sigma_ci, sigma_d):
    """Convenience function to add noise to a simulated dataset."""
    lognorm = jax.random.lognormal
    binom = jax.random.binomial
    key, k0, k1, k2, k3 = jax.random.split(key, 5)
    # generate multiplicative noise for the continuous state variables
    y0_noise = lognorm(k0, sigma=sigma_ce, shape=Y[:, :, 0].shape)
    y1_noise = lognorm(k1, sigma=sigma_ci, shape=Y[:, :, 1].shape)
    y2_noise = lognorm(k2, sigma=sigma_d, shape=Y[:, :, 2].shape)
    # elementwise multiplication of noise and signal over the (id, time) axes
    y0 = jnp.einsum("jm,jm -> jm", y0_noise, Y[:, :, 0])
    y1 = jnp.einsum("jm,jm -> jm", y1_noise, Y[:, :, 1])
    y2 = jnp.einsum("jm,jm -> jm", y2_noise, Y[:, :, 2])
    y3 = binom(k3, p=Y[:, :, 3], n=n)
    # stack along the last axis
    return jnp.stack([y0, y1, y2, y3], axis=2)
def prob_model(solver, time, y0, obs=None, masks=None):
    """The probabilistic model used to sample the posterior parameter
    distributions of the ODE model."""
    # parameters of the deterministic model
    k_i = numpyro.sample("k_i", LogNormal(jnp.log(0.1), scale=1))
    k_e = numpyro.sample("k_e", LogNormal(jnp.log(0.05), scale=1))
    k_a = numpyro.sample("k_a", LogNormal(jnp.log(0.2), scale=1))
    k_r = numpyro.sample("k_r", LogNormal(jnp.log(0.01), scale=1))
    z = numpyro.sample("z", LogNormal(jnp.log(0.01), scale=1))
    kk = numpyro.sample("kk", LogNormal(jnp.log(0.01), scale=1))
    h_b = numpyro.sample("h_b", LogNormal(jnp.log(0.01), scale=1))
    eps = 1e-8
    # parameters of the error model
    # Currently the scale parameter of the error distribution is fixed (scale=0.1).
    # It could also be sampled, e.g. from a half-normal or some other positive-
    # constrained prior. Parameters must be passed in the correct order (or,
    # better, provided as a dictionary and extracted accordingly).
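    # Illustrative sketch only (not used in the runs below), assuming HalfNormal
    # is imported from numpyro.distributions:
    #     sigma_ce = numpyro.sample("sigma_ce", HalfNormal(1.0))
    # and then scale=sigma_ce would replace the fixed scale=0.1 further down.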
    theta = (k_i, k_e, k_a, k_r, z, kk, h_b)
    # compute the deterministic model and store the variables
    res = solver(theta=theta, time=time, y0=y0)
    Ce = numpyro.deterministic("Ce", res[:, :, 0])
    Ci = numpyro.deterministic("Ci", res[:, :, 1])
    D = numpyro.deterministic("D", res[:, :, 2])
    L = numpyro.deterministic("L", res[:, :, 3])
    # calculate the likelihood only if observations are specified
    if obs is not None:
        numpyro.sample("Ce_obs", LogNormal(jnp.log(Ce + eps), scale=0.1), obs=obs[:, :, 0] + eps)
        numpyro.sample("Ci_obs", LogNormal(jnp.log(Ci + eps), scale=0.1), obs=obs[:, :, 1] + eps)
        numpyro.sample("D_obs", LogNormal(jnp.log(D + eps), scale=0.1), obs=obs[:, :, 2] + eps)
        numpyro.sample("L_obs", Binomial(total_count=10, probs=L), obs=obs[:, :, 3])
def run_inference(key, solver, time, y0, obs=None, masks=None):
    """Run MCMC with the NUTS kernel and return an arviz.InferenceData."""
    # initialize the probability model with constants
    model = partial(prob_model, solver=solver, time=time, y0=y0, obs=obs, masks=masks)
    kernel = NUTS(
        model=model,
        dense_mass=True,
        step_size=0.01,
        adapt_mass_matrix=True,
        adapt_step_size=True,
        max_tree_depth=10,
        target_accept_prob=0.8,
        init_strategy=init_to_median
    )
    mcmc = MCMC(
        sampler=kernel,
        num_warmup=2000,
        num_samples=2000,
        num_chains=1,
        progress_bar=True,
    )
    # run inference
    mcmc.run(key)
    mcmc.print_summary()
    # create arviz InferenceData
    data_vars = ["Ce", "Ci", "D", "L"]
    loglik = ["Ce_obs", "Ci_obs", "D_obs", "L_obs"]
    dims = ["id", "time"]
    idata = az.from_numpyro(
        mcmc,
        dims={k: dims for k in data_vars + loglik},
        coords={"time": time, "id": list(range(len(y0)))},
    )
    idata.add_groups({"posterior_predictive": idata.posterior[data_vars]})
    idata.posterior = idata.posterior.drop(data_vars)
    return idata, mcmc, kernel
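# Side note (not used in the runs below): MCMC also accepts num_chains > 1
# together with chain_method="parallel", "sequential", or "vectorized"; this
# parallelizes chains rather than the likelihood/gradient evaluation itself.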
if __name__ == "__main__":
    print(
        f"jax {jax.__version__}",
        f"numpyro {numpyro.__version__}",
        f"diffrax {diffrax.__version__}",
        f"arviz {az.__version__}"
    )
    time = jnp.linspace(0, 120, 100)
    theta_0 = jnp.array([0.1, 0.05, 0.2, 0.01, 50, 0.0001, 1e-8])
    y0_batch = jnp.array(
        [
            [10, 0, 0],
            [20, 0, 0],
            [30, 0, 0],
            [40, 0, 0],
            [50, 0, 0],
        ]
    )
    # just to demonstrate the solver for a single measurement series;
    # without a batch dimension this is not really useful.
    result_single = odesolve(y0=jnp.array([10, 0, 0]), time=time, theta=theta_0)
    # with only a few measurement series the model runs very fast
    result_batch = odesolve_batch(y0=y0_batch, time=time, theta=theta_0)
    # generate random keys
    key = jax.random.PRNGKey(1)
    key, *subkeys = jax.random.split(key, 3)
    obs_noisy_batch = add_noise_jax(
        Y=result_batch, n=10, key=subkeys[0],
        sigma_ce=0.1, sigma_ci=0.1, sigma_d=0.1
    )
    idata, mcmc, kernel = run_inference(
        key=subkeys[1],
        solver=odesolve_batch,
        time=time,
        y0=y0_batch,
        obs=obs_noisy_batch
    )
    az.plot_trace(idata)
    az.plot_pair(idata)
    # with a batch of 300 measurement series the model becomes very slow once
    # NUTS reaches a number of steps >~ 100.
    # The y0 here are just dummy data; in reality the y0 values differ.
    y0_batch = jnp.repeat(jnp.array([[10, 0, 0]]), repeats=300, axis=0)
    result_batch = odesolve_batch(y0=y0_batch, time=time, theta=theta_0)
    obs_noisy_batch = add_noise_jax(
        Y=result_batch, n=10, key=subkeys[0],
        sigma_ce=0.1, sigma_ci=0.1, sigma_d=0.1
    )
    idata, mcmc, kernel = run_inference(
        key=subkeys[1],
        solver=odesolve_batch,
        time=time,
        y0=y0_batch,
        obs=obs_noisy_batch
    )
    az.plot_trace(idata)
    az.plot_pair(idata)