Seeking advice on optimizing a Diffrax/NumPyro implementation of a pharmacokinetic-pharmacodynamic (PK-PD) model
I’ve reimplemented a published PK-PD model using diffrax and numpyro, and I’m looking for advice on optimizing performance, particularly around JAX usage and MCMC sampling efficiency.
I have been working on setting up a Bayesian model for this problem, but I have struggled with long sampling times, which make it hard to evaluate alternative priors and model structures efficiently. For now I use a very small number of samples for illustrative purposes. I am looking for advice on whether I am following JAX best practices for high performance. I would greatly appreciate any suggestions.
Model Overview
[Aldea R, Grimm HP, Gieschke R, et al. In silico exploration of amyloid-related imaging abnormalities in the gantenerumab open-label extension trials using a semi-mechanistic model. Alzheimer’s Dement. 2022; 8:e12306. https://doi.org/10.1002/trc2.12306]
- PK: Three-compartment model (absorption, central, peripheral)
- PD: Two-compartment model (amyloid beta, VWD)
- Sigmoidal response function for converting VWD to BGTS score
The model tracks drug concentration through absorption, central, and peripheral compartments, then models its effect on amyloid beta and vascular wall damage (VWD). The BGTS score (biomarker) is calculated from VWD using a sigmoidal response function. The PK parameters are fixed to literature values, while I’m trying to estimate:
- alpha_removal: Rate of amyloid clearance
- k_repair: VWD repair rate
- A_beta0: Initial amyloid level
- response_power: Sigmoidal response steepness
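For reference, the system implemented by `combined_model` and `compute_bgts` in the script below can be written as follows. Here $D_{sc}(t)$ is the value returned by `apply_dose` (the scheduled dose during its $D_1$-day infusion window, zero otherwise) and $\gamma$ is `response_power`:

$$
\begin{aligned}
\frac{dA}{dt} &= \frac{F\,D_{sc}(t)}{D_1} - K_A\,A \\
\frac{dC}{dt} &= \frac{K_A\,A - CL\,C - Q\,C + Q\,C_p}{V_c} \\
\frac{dC_p}{dt} &= \frac{Q\,(C - C_p)}{V_p} \\
\frac{dA_\beta}{dt} &= -\alpha_{\text{removal}}\,C\,A_\beta \\
\frac{d\,\mathrm{VWD}}{dt} &= \alpha_{\text{removal}}\,C\,A_\beta - k_{\text{repair}}\,\mathrm{VWD} \\
\mathrm{BGTS} &= \mathrm{BGTS}_{\max}\,\frac{(\mathrm{VWD}/EG_{50})^{\gamma}}{1 + (\mathrm{VWD}/EG_{50})^{\gamma}}
\end{aligned}
$$

The initial conditions in `simulate_model` are $A = C = C_p = \mathrm{VWD} = 0$ and $A_\beta = A_{\beta,0}$. With these, the PK states $A$, $C$, $C_p$ do not depend on any of the four estimated parameters. The amyloid equation also integrates in closed form to $A_\beta(t) = A_{\beta,0}\exp\left(-\alpha_{\text{removal}}\int_0^t C(s)\,ds\right)$.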
Note: the data used in the script are synthetic, not the real data from the paper.
I’m using diffrax’s Dopri5 solver with a fixed step size (`dt0=0.01` days and no step-size controller). When I use a PID step-size controller instead, the solve tends to fail to finish within `max_steps`.
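As a rough sense of scale, the arithmetic below uses only values from the script. It is back-of-the-envelope, not a profile. Without a step-size controller, every step from t=0 to the last observation uses `dt0=0.01`, and the 0.0821-day infusion window in `apply_dose` spans only about eight steps. During that window the absorption compartment receives F·Dsc/D1 per day, so each infusion adds F·Dsc in total.

```python
# Values taken from the script (times in days, doses in mg).
t_last = 96 * 7      # last observation time (week 96)
dt0 = 0.01           # fixed step size passed to diffeqsolve
D1 = 0.0821          # infusion duration used by apply_dose
F = 0.494            # bioavailability

n_steps = round(t_last / dt0)        # Dopri5 steps per forward solve
steps_per_infusion = D1 / dt0        # steps falling inside one infusion window
absorbed_450 = F * 450               # total input F*Dsc delivered by a 450 mg dose

print(n_steps)                       # 67200
print(round(steps_per_infusion, 2))  # 8.21
print(round(absorbed_450, 1))        # 222.3
```

Adaptive controllers typically struggle with on/off discontinuities like the start and end of each infusion window. They either reject many steps around them or grow the step enough to skip a window entirely, which may explain the PID failures. Splitting the solve at the dose times is a common workaround.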
Current Performance
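- Runtime: 1661.61 seconds (27.69 minutes), measured as total script wall time, so it includes JIT compilation and plotting
- Chains: 2
- Warmup samples: 50 per chain
- Post-warmup samples: 100 per chain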
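The arithmetic below gives an upper bound on the work per chain. It assumes every NUTS iteration hits `max_tree_depth=6`, since a tree of depth 6 takes at most 2**6 - 1 = 63 leapfrog steps. Each leapfrog step needs one gradient of the log density: one ODE solve plus its reverse pass. The 67,200 steps per solve come from the fixed-step setup above, and initialization is ignored.

```python
num_warmup, num_samples = 50, 100
max_tree_depth = 6
steps_per_solve = 67_200  # fixed-step Dopri5 steps from t=0 to t=672 with dt0=0.01

max_leapfrog = 2 ** max_tree_depth - 1                    # leapfrog steps per iteration, at most
max_grad_evals = (num_warmup + num_samples) * max_leapfrog  # gradient evaluations per chain, at most
max_solver_steps = max_grad_evals * steps_per_solve       # forward solver steps per chain, at most

print(max_leapfrog)      # 63
print(max_grad_evals)    # 9450
print(max_solver_steps)  # 635040000
```

The chains run in parallel, so these figures are per chain. Because the measured runtime also includes compilation and plotting, they only indicate the order of magnitude.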
System Details
- Hardware: Apple M3 Pro
- Python: 3.12
- Key dependencies: jax, numpyro, diffrax
I’m particularly interested in:
- Best practices for JIT/VMAP usage with ODEs and MCMC sampling
- Efficient automatic differentiation setup
- MCMC parameter tuning for this type of model
- Whether blackjax might be more suitable
- Any obvious performance bottlenecks in my implementation
- Whether this is actually reasonable performance and I am simply asking for too much
I would like to run this code efficiently with much longer chains. Seeing this performance on a relatively small model makes me question whether fitting larger models the same way is feasible.
Full implementation code below. Any suggestions for improving performance would be greatly appreciated!
"""
Example implementation of a PK-PD model using diffrax and numpyro.
This model simulates and fits BGTS (biomarker) data using a system of ODEs.
The model consists of:
- PK: Three-compartment model (absorption, central, peripheral)
- PD: Two-compartment model (amyloid beta, VWD)
- Hill equation for converting VWD to BGTS score
"""
import os
# Set up JAX for parallel chains. XLA_FLAGS is read when JAX initializes its
# CPU backend, so set it before importing jax (or numpyro/diffrax, which import
# jax) to make sure num_chains CPU devices are available to
# chain_method='parallel'.
num_chains = 2
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_chains}"
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_sample
import diffrax
import arviz as az
import matplotlib.pyplot as plt
import time
import numpy as np
jax.config.update('jax_platform_name', 'cpu')
# Create synthetic data that matches real BGTS progression pattern
# Times are in days (converted from weeks for the model)
synthetic_times = np.array([0, 12, 16, 24, 28, 36, 40, 48, 60, 72, 84, 96]) * 7
synthetic_bgts = np.array([0.5, 1.0, 2.0, 10.0, 3.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0, 0.5])
# Dose schedule: (time in days, dose in mg)
dose_schedule = [
(0, 450), (32, 450), (56, 900), (84, 900),
(102, 1200), (140, 1200), (280, 1200), (309, 1200),
(336, 1200), (365, 1200), (420, 1200), (455, 1200),
(483, 1200), (504, 1200), (529, 1200), (560, 1200),
(588, 1200), (616, 1200), (644, 1200), (675, 1200),
(703, 1200)
]
def combined_model(t, y, args):
"""Combined PK-PD model as a single ODE system
Args:
t: Time point
y: State vector [A, C, Cp, A_beta, VWD]
A: Amount in absorption compartment
C: Central compartment concentration
Cp: Peripheral compartment concentration
A_beta: Amyloid beta level
VWD: VWD level
args: Model parameters (F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair, Dsc)
Returns:
Array of derivatives for each state variable
"""
A, C, Cp, A_beta, VWD = y
F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair, Dsc = args
# PK equations
dA_dt = F * Dsc / D1 - KA * A # Absorption compartment
dC_dt = (KA * A - CL * C - Q * C + Q * Cp) / Vc # Central compartment
dCp_dt = (Q * C - Q * Cp) / Vp # Peripheral compartment
# PD equations
dA_beta_dt = -alpha_removal * C * A_beta # Amyloid beta degradation
dVWD_dt = alpha_removal * C * A_beta - k_repair * VWD # VWD progression
return jnp.array([dA_dt, dC_dt, dCp_dt, dA_beta_dt, dVWD_dt])
def compute_bgts(VWD, BGTS_max=60.0, EG50=1.0, response_power=3.72):
"""Calculate BGTS score from VWD using sigmoidal response function
Args:
VWD: VWD level
BGTS_max: Maximum possible BGTS score
EG50: VWD level producing 50% of max response
response_power: Power parameter controlling steepness of response
Returns:
BGTS score
"""
return BGTS_max * ((VWD/EG50)**response_power) / (1.0 + (VWD/EG50)**response_power)
def apply_dose(t, dose_schedule, D1):
"""Compute dose at time t based on dosing schedule
Args:
t: Current time
dose_schedule: List of (time, dose) tuples
D1: Duration of dose administration
Returns:
Active dose at time t
"""
times = jnp.array([d[0] for d in dose_schedule])
doses = jnp.array([d[1] for d in dose_schedule])
active_doses = jnp.where((t >= times) & (t < times + D1), doses, 0.0)
return jnp.sum(active_doses)
@jax.jit
def simulate_model(params, dose_schedule, t_span):
"""Simulate the combined PK-PD model using diffrax solver
Args:
params: Array of [alpha_removal, k_repair, A_beta0, response_power]
dose_schedule: List of (time, dose) tuples
t_span: Time points to evaluate the model
Returns:
Tuple of (full solution array, BGTS values)
"""
# Fixed PK parameters from literature
F, D1, KA, CL, Vc, Q, Vp = 0.494, 0.0821, 0.220, 0.336, 3.52, 0.869, 6.38
alpha_removal, k_repair, A_beta0, response_power = params
# Initial conditions
y0 = jnp.array([0.0, 0.0, 0.0, A_beta0, 0.0])
def model_rhs(t, y, args):
Dsc = apply_dose(t, dose_schedule, D1)
return combined_model(t, y, (*args, Dsc))
# Solve ODE system using diffrax
solution = diffrax.diffeqsolve(
diffrax.ODETerm(model_rhs),
diffrax.Dopri5(),
t0=0,
t1=t_span[-1],
dt0=0.01,
y0=y0,
saveat=diffrax.SaveAt(ts=t_span), # Save only at data points
args=(F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair),
max_steps=500000
)
# Calculate BGTS from VWD
VWD = solution.ys[:, 4]
BGTS = compute_bgts(VWD, response_power=response_power)
return solution.ys, BGTS
def plot_predictions(samples, data, dose_schedule):
"""Plot model predictions against data with uncertainty bands
Args:
samples: MCMC samples
data: Dictionary containing observed data
dose_schedule: List of (time, dose) tuples
"""
# Create dense time points for smooth plotting
t_plot = jnp.linspace(0, 7*120, 200)
# Generate predictions for each sample
predictions = []
concentrations = []
for i in range(min(100, len(samples['alpha_removal']))): # Use 100 samples for efficiency
params = jnp.array([
samples['alpha_removal'][i],
samples['k_repair'][i],
samples['A_beta0'][i],
samples['response_power'][i]
])
solution, bgts = simulate_model(params, dose_schedule, t_plot)
predictions.append(bgts)
concentrations.append(solution[:, 1])
predictions = jnp.array(predictions)
concentrations = jnp.array(concentrations)
# Calculate means and credible intervals
mean_pred = jnp.mean(predictions, axis=0)
lower_pred = jnp.percentile(predictions, 2.5, axis=0)
upper_pred = jnp.percentile(predictions, 97.5, axis=0)
mean_conc = jnp.mean(concentrations, axis=0)
lower_conc = jnp.percentile(concentrations, 2.5, axis=0)
upper_conc = jnp.percentile(concentrations, 97.5, axis=0)
# Create plots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), height_ratios=[1, 1])
# Plot PK concentrations and doses
ax1.fill_between(t_plot/7, lower_conc, upper_conc, alpha=0.2, color='purple',
label='95% CI Central Concentration')
ax1.plot(t_plot/7, mean_conc, 'purple', label='Mean Central Concentration')
# Add doses on secondary y-axis
ax1_doses = ax1.twinx()
dose_times = [d[0]/7 for d in dose_schedule]
dose_levels = [d[1] for d in dose_schedule]
ax1_doses.bar(dose_times, dose_levels, alpha=0.3, color='gray', width=0.5, label='Doses (mg)')
# Configure plots
ax1.set_xlabel('Time (weeks)')
ax1.set_ylabel('Concentration (μg/mL)')
ax1_doses.set_ylabel('Dose (mg)')
ax1.grid(True, alpha=0.3)
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax1_doses.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
# Plot BGTS predictions
ax2.fill_between(t_plot/7, lower_pred, upper_pred, alpha=0.2, color='blue',
label='95% CI BGTS')
ax2.plot(t_plot/7, mean_pred, 'b-', label='Mean BGTS')
ax2.scatter(data['times']/7, data['bgts'], color='red', label='Observed Data', zorder=5)
ax2.set_xlabel('Time (weeks)')
ax2.set_ylabel('BGTS Score')
ax2.grid(True, alpha=0.3)
ax2.legend()
ax1.set_title('Central Compartment Concentration and Dosing Schedule')
ax2.set_title('BGTS Score Predictions')
plt.tight_layout()
plt.show()
def run_inference(data, dose_schedule, num_warmup=1000, num_samples=2000, num_chains=num_chains):
"""Run MCMC inference using NUTS sampler
Args:
data: Dictionary containing observed data
dose_schedule: List of (time, dose) tuples
num_warmup: Number of warmup samples per chain
num_samples: Number of post-warmup samples per chain
num_chains: Number of parallel chains
Returns:
Tuple of (MCMC object, posterior samples)
"""
def model(data):
# Sample parameters from priors
alpha_removal = numpyro.sample('alpha_removal',
dist.TruncatedNormal(0.0001, 0.1*0.0001, low=0.0, high=1.0))
k_repair = numpyro.sample('k_repair',
dist.TruncatedNormal(0.01, 0.1*0.01, low=0.0, high=1.0))
A_beta0 = numpyro.sample('A_beta0',
dist.TruncatedNormal(2.0, 0.1*2.0, low=0.0, high=10.0))
response_power = numpyro.sample('response_power',
dist.TruncatedNormal(4.0, 0.05*4.0, low=0.0001, high=10.0))
# Measurement error
sigma = numpyro.sample('sigma', dist.HalfNormal(2.5))
# Compute model prediction
params = jnp.array([alpha_removal, k_repair, A_beta0, response_power])
_, bgts_pred = simulate_model(params, dose_schedule, data['times'])
# Likelihood
numpyro.sample('obs',
dist.Normal(bgts_pred, sigma),
obs=data['bgts'])
# Initialize NUTS sampler
nuts_kernel = NUTS(model,
target_accept_prob=0.8,
max_tree_depth=6,
init_strategy=init_to_sample,
dense_mass=True)
# Run MCMC
mcmc = MCMC(nuts_kernel,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=num_chains,
chain_method='parallel',
progress_bar=True,
jit_model_args=True)
# Generate random keys for each chain
rng_key = random.PRNGKey(0)
rng_keys = random.split(rng_key, num_chains)
# Run inference
print("Running inference...")
mcmc.run(rng_keys, data)
# Plot diagnostics
samples = mcmc.get_samples()
az.plot_trace(az.from_numpyro(mcmc))
plt.tight_layout()
plt.show()
# Plot predictions
plot_predictions(samples, data, dose_schedule)
return mcmc, samples
if __name__ == "__main__":
start_time = time.time()
# Create data dictionary with synthetic data
data = {
'times': jnp.array(synthetic_times),
'bgts': jnp.array(synthetic_bgts)
}
# Run inference with reduced samples for example
mcmc, samples = run_inference(data, dose_schedule, num_warmup=50, num_samples=100, num_chains=num_chains)
# Print results
print("\nParameter Estimates:")
print("-" * 50)
mcmc.print_summary()
print("\nMean Parameter Values:")
print("-" * 50)
for param in ['alpha_removal', 'k_repair', 'A_beta0', 'response_power', 'sigma']:
mean_value = jnp.mean(samples[param])
std_value = jnp.std(samples[param])
print(f"{param}: {mean_value:.6f} ± {std_value:.6f}")
end_time = time.time()
runtime = end_time - start_time
print("\nRuntime Information:")
print("-" * 50)
print(f"Total runtime: {runtime:.2f} seconds ({runtime/60:.2f} minutes)")
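Below is one possible restructuring, sketched under simplifying assumptions. In `combined_model`, the PK states (A, C, Cp) do not depend on `alpha_removal`, `k_repair`, `A_beta0` or `response_power`. In principle, C(t) could therefore be computed once on a fine grid outside the sampler, leaving only the two PD equations inside the likelihood.

The hypothetical helper `pd_from_concentration` assumes C is piecewise constant over each grid step of width `h`. Under that assumption, A_beta follows exactly from the cumulative area under C. VWD is advanced with an exponential-integrator step that holds the inflow `alpha_removal * C * A_beta` fixed at its start-of-step value, which is only first-order accurate. This is a sketch for discussion, not a validated replacement for the diffrax solve.

```python
import jax
import jax.numpy as jnp


def pd_from_concentration(C_grid, h, alpha_removal, k_repair, A_beta0):
    """Hypothetical helper: propagate A_beta and VWD given C on a uniform grid.

    Assumes C is piecewise constant on each interval [t_n, t_n + h).
    Returns A_beta and VWD at every grid point t_n = n * h (VWD(0) = 0).
    """
    # Cumulative area under C at each grid point (left Riemann sum, AUC(0) = 0).
    auc = jnp.concatenate([jnp.zeros(1), jnp.cumsum(C_grid[:-1]) * h])
    # dA_beta/dt = -alpha * C * A_beta  =>  A_beta = A_beta0 * exp(-alpha * AUC)
    A_beta = A_beta0 * jnp.exp(-alpha_removal * auc)

    decay = jnp.exp(-k_repair * h)          # exp(-k h)
    gain = -jnp.expm1(-k_repair * h) / k_repair  # (1 - exp(-k h)) / k, computed stably

    def step(vwd, inflow):
        # Exact solution of dVWD/dt = inflow - k * VWD with the inflow held fixed
        vwd_next = vwd * decay + inflow * gain
        return vwd_next, vwd_next

    # Inflow alpha * C * A_beta evaluated at the start of each step.
    inflow = alpha_removal * C_grid[:-1] * A_beta[:-1]
    _, vwd_tail = jax.lax.scan(step, jnp.zeros(()), inflow)
    VWD = jnp.concatenate([jnp.zeros(1), vwd_tail])
    return A_beta, VWD
```

As a quick check, with a constant concentration C = 1, the result at t = 100 days should match the closed form VWD(t) = αcA₀(e^{-αct} − e^{-kt})/(k − αc). In a NumPyro model, `C_grid` would be precomputed (for example with the existing diffrax setup), and BGTS would be read off at the grid indices of the observation times.

```python
h = 0.01
C_const = jnp.ones(10001)  # grid covering 0 ... 100 days
A_beta, VWD = pd_from_concentration(C_const, h, 0.01, 0.05, 2.0)
exact = 0.02 * (jnp.exp(-1.0) - jnp.exp(-5.0)) / 0.04
print(float(VWD[-1]), float(exact))
```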