Seeking advice on MCMC sampling for an ODE model solved with Diffrax

Seeking advice on optimizing Diffrax/Numpyro implementation of Pharmacokinetic-Pharmacodynamic (PK-PD) model

I’ve reimplemented a published PK-PD model using diffrax and numpyro, and I’m looking for advice on optimizing performance, particularly around JAX usage and MCMC sampling efficiency.
I have been setting up a Bayesian model for this problem but have struggled with long sampling times, which hinder my ability to evaluate alternative priors and model structures efficiently. For now I am using a very small number of samples for illustrative purposes, and I am mainly looking for advice on whether I am following JAX best practices for performance. Any advice would be greatly appreciated.

Model Overview

[Aldea R, Grimm HP, Gieschke R, et al. In silico exploration of amyloid-related imaging abnormalities in the gantenerumab open-label extension trials using a semi-mechanistic model. Alzheimer’s Dement. 2022; 8:e12306. https://doi.org/10.1002/trc2.12306]

  • PK: Three-compartment model (absorption, central, peripheral)
  • PD: Two-compartment model (amyloid beta, VWD)
  • Sigmoidal response function for converting VWD to BGTS score

The model tracks drug concentration through absorption, central, and peripheral compartments, then models its effect on amyloid beta and vascular wall damage (VWD). The BGTS score (biomarker) is calculated from VWD using a sigmoidal response function. Most PK parameters are fixed from literature, while I’m trying to estimate:

  • alpha_removal: Rate of amyloid clearance
  • k_repair: VWD repair rate
  • A_beta0: Initial amyloid level
  • response_power: Sigmoidal response steepness
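
For reference, the sigmoidal response mentioned above (with BGTS_max = 60 and EG50 = 1 fixed, matching compute_bgts in the code below) is:

BGTS = BGTS_max * (VWD / EG50)^response_power / (1 + (VWD / EG50)^response_power)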

Note: the data used here is synthetic, not the true data from the paper.

I’m using diffrax’s Dopri5 solver with a fixed step size; with an adaptive PIDController the solve tends to hit max_steps and fail, presumably because the dosing input is discontinuous.

Current Performance

  • Runtime: 1661.61 seconds (27.69 minutes)
  • Chains: 2
  • Warmup samples: 50
  • Post-warmup samples: 100

System Details

  • Hardware: Apple M3 Pro
  • Python: 3.12
  • Key dependencies: jax, numpyro, diffrax

I’m particularly interested in:

  1. Best practices for JIT/VMAP usage with ODEs and MCMC sampling
  2. Efficient automatic differentiation setup
  3. MCMC parameter tuning for this type of model
  4. Whether blackjax might be more suitable
  5. Any obvious performance bottlenecks in my implementation
  6. Whether this is actually reasonable performance and I am simply asking for too much

I would like to be able to run this code efficiently for much longer chains; this initial performance on a relatively small model has me questioning whether it is feasible to fit larger models in a similar manner.

Full implementation code below. Any suggestions for improving performance would be greatly appreciated!

"""
Example implementation of a PK-PD model using diffrax and numpyro.
This model simulates and fits BGTS (biomarker) data using a system of ODEs.

The model consists of:
- PK: Three-compartment model (absorption, central, peripheral)
- PD: Two-compartment model (amyloid beta, VWD)
- Hill equation for converting VWD to BGTS score
"""

import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_sample
import diffrax
import arviz as az
import matplotlib.pyplot as plt
import os
import time
import numpy as np

# Set up JAX for parallel chains (must run before any JAX computation)
num_chains = 2
numpyro.set_host_device_count(num_chains)
jax.config.update('jax_platform_name', 'cpu')

# Create synthetic data that matches real BGTS progression pattern
# Times are in days (converted from weeks for the model)
synthetic_times = np.array([0, 12, 16, 24, 28, 36, 40, 48, 60, 72, 84, 96]) * 7
synthetic_bgts = np.array([0.5, 1.0, 2.0, 10.0, 3.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0, 0.5])

# Dose schedule: (time in days, dose in mg)
dose_schedule = [
    (0, 450), (32, 450), (56, 900), (84, 900),
    (102, 1200), (140, 1200), (280, 1200), (309, 1200),
    (336, 1200), (365, 1200), (420, 1200), (455, 1200),
    (483, 1200), (504, 1200), (529, 1200), (560, 1200),
    (588, 1200), (616, 1200), (644, 1200), (675, 1200),
    (703, 1200)
]

def combined_model(t, y, args):
    """Combined PK-PD model as a single ODE system
    
    Args:
        t: Time point
        y: State vector [A, C, Cp, A_beta, VWD]
           A: Amount in absorption compartment
           C: Central compartment concentration
           Cp: Peripheral compartment concentration
           A_beta: Amyloid beta level
           VWD: VWD level
        args: Model parameters (F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair, Dsc)
    
    Returns:
        Array of derivatives for each state variable
    """
    A, C, Cp, A_beta, VWD = y
    F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair, Dsc = args
    
    # PK equations
    dA_dt = F * Dsc / D1 - KA * A  # Absorption compartment
    dC_dt = (KA * A - CL * C - Q * C + Q * Cp) / Vc  # Central compartment
    dCp_dt = (Q * C - Q * Cp) / Vp  # Peripheral compartment
    
    # PD equations
    dA_beta_dt = -alpha_removal * C * A_beta  # Amyloid beta degradation
    dVWD_dt = alpha_removal * C * A_beta - k_repair * VWD  # VWD progression
    
    return jnp.array([dA_dt, dC_dt, dCp_dt, dA_beta_dt, dVWD_dt])

def compute_bgts(VWD, BGTS_max=60.0, EG50=1.0, response_power=3.72):
    """Calculate BGTS score from VWD using sigmoidal response function
    
    Args:
        VWD: VWD level
        BGTS_max: Maximum possible BGTS score
        EG50: VWD level producing 50% of max response
        response_power: Power parameter controlling steepness of response
    
    Returns:
        BGTS score
    """
    return BGTS_max * ((VWD/EG50)**response_power) / (1.0 + (VWD/EG50)**response_power)

def apply_dose(t, dose_schedule, D1):
    """Compute dose at time t based on dosing schedule
    
    Args:
        t: Current time
        dose_schedule: List of (time, dose) tuples
        D1: Duration of dose administration
    
    Returns:
        Active dose at time t
    """
    times = jnp.array([d[0] for d in dose_schedule])
    doses = jnp.array([d[1] for d in dose_schedule])
    active_doses = jnp.where((t >= times) & (t < times + D1), doses, 0.0)
    return jnp.sum(active_doses)

@jax.jit
def simulate_model(params, dose_schedule, t_span):
    """Simulate the combined PK-PD model using diffrax solver
    
    Args:
        params: Array of [alpha_removal, k_repair, A_beta0, response_power]
        dose_schedule: List of (time, dose) tuples
        t_span: Time points to evaluate the model
    
    Returns:
        Tuple of (full solution array, BGTS values)
    """
    # Fixed PK parameters from literature
    F, D1, KA, CL, Vc, Q, Vp = 0.494, 0.0821, 0.220, 0.336, 3.52, 0.869, 6.38
    alpha_removal, k_repair, A_beta0, response_power = params
    
    # Initial conditions
    y0 = jnp.array([0.0, 0.0, 0.0, A_beta0, 0.0])
    
    def model_rhs(t, y, args):
        Dsc = apply_dose(t, dose_schedule, D1)
        return combined_model(t, y, (*args, Dsc))
    
    # Solve ODE system using diffrax
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(model_rhs),
        diffrax.Dopri5(),
        t0=0,
        t1=t_span[-1],
        dt0=0.01,  # default ConstantStepSize controller, so roughly t1/dt0 solver steps per solve
        y0=y0,
        saveat=diffrax.SaveAt(ts=t_span),  # Save only at data points
        args=(F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair),
        max_steps=500000
    )
    
    # Calculate BGTS from VWD
    VWD = solution.ys[:, 4]
    BGTS = compute_bgts(VWD, response_power=response_power)
    
    return solution.ys, BGTS

def plot_predictions(samples, data, dose_schedule):
    """Plot model predictions against data with uncertainty bands
    
    Args:
        samples: MCMC samples
        data: Dictionary containing observed data
        dose_schedule: List of (time, dose) tuples
    """
    # Create dense time points for smooth plotting
    t_plot = jnp.linspace(0, 7*120, 200)
    
    # Generate predictions for each sample
    predictions = []
    concentrations = []
    for i in range(min(100, len(samples['alpha_removal']))): # Use 100 samples for efficiency
        params = jnp.array([
            samples['alpha_removal'][i],
            samples['k_repair'][i],
            samples['A_beta0'][i],
            samples['response_power'][i]
        ])
        solution, bgts = simulate_model(params, dose_schedule, t_plot)
        predictions.append(bgts)
        concentrations.append(solution[:, 1])
    
    predictions = jnp.array(predictions)
    concentrations = jnp.array(concentrations)
    
    # Calculate means and credible intervals
    mean_pred = jnp.mean(predictions, axis=0)
    lower_pred = jnp.percentile(predictions, 2.5, axis=0)
    upper_pred = jnp.percentile(predictions, 97.5, axis=0)
    
    mean_conc = jnp.mean(concentrations, axis=0)
    lower_conc = jnp.percentile(concentrations, 2.5, axis=0)
    upper_conc = jnp.percentile(concentrations, 97.5, axis=0)
    
    # Create plots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), height_ratios=[1, 1])
    
    # Plot PK concentrations and doses
    ax1.fill_between(t_plot/7, lower_conc, upper_conc, alpha=0.2, color='purple', 
                     label='95% CI Central Concentration')
    ax1.plot(t_plot/7, mean_conc, 'purple', label='Mean Central Concentration')
    
    # Add doses on secondary y-axis
    ax1_doses = ax1.twinx()
    dose_times = [d[0]/7 for d in dose_schedule]
    dose_levels = [d[1] for d in dose_schedule]
    ax1_doses.bar(dose_times, dose_levels, alpha=0.3, color='gray', width=0.5, label='Doses (mg)')
    
    # Configure plots
    ax1.set_xlabel('Time (weeks)')
    ax1.set_ylabel('Concentration (μg/mL)')
    ax1_doses.set_ylabel('Dose (mg)')
    ax1.grid(True, alpha=0.3)
    
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax1_doses.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    
    # Plot BGTS predictions
    ax2.fill_between(t_plot/7, lower_pred, upper_pred, alpha=0.2, color='blue', 
                     label='95% CI BGTS')
    ax2.plot(t_plot/7, mean_pred, 'b-', label='Mean BGTS')
    ax2.scatter(data['times']/7, data['bgts'], color='red', label='Observed Data', zorder=5)
    
    ax2.set_xlabel('Time (weeks)')
    ax2.set_ylabel('BGTS Score')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    ax1.set_title('Central Compartment Concentration and Dosing Schedule')
    ax2.set_title('BGTS Score Predictions')
    
    plt.tight_layout()
    plt.show()

def run_inference(data, dose_schedule, num_warmup=1000, num_samples=2000, num_chains=num_chains):
    """Run MCMC inference using NUTS sampler
    
    Args:
        data: Dictionary containing observed data
        dose_schedule: List of (time, dose) tuples
        num_warmup: Number of warmup samples per chain
        num_samples: Number of post-warmup samples per chain
        num_chains: Number of parallel chains
    
    Returns:
        Tuple of (MCMC object, posterior samples)
    """
    
    def model(data):
        # Sample parameters from priors
        alpha_removal = numpyro.sample('alpha_removal', 
            dist.TruncatedNormal(0.0001, 0.1*0.0001, low=0.0, high=1.0))
        k_repair = numpyro.sample('k_repair',
            dist.TruncatedNormal(0.01, 0.1*0.01, low=0.0, high=1.0))
        A_beta0 = numpyro.sample('A_beta0',
            dist.TruncatedNormal(2.0, 0.1*2.0, low=0.0, high=10.0))
        response_power = numpyro.sample('response_power',
            dist.TruncatedNormal(4.0, 0.05*4.0, low=0.0001, high=10.0))
        
        # Measurement error
        sigma = numpyro.sample('sigma', dist.HalfNormal(2.5))
        
        # Compute model prediction
        params = jnp.array([alpha_removal, k_repair, A_beta0, response_power])
        _, bgts_pred = simulate_model(params, dose_schedule, data['times'])
        
        # Likelihood
        numpyro.sample('obs', 
            dist.Normal(bgts_pred, sigma),
            obs=data['bgts'])
    
    # Initialize NUTS sampler
    nuts_kernel = NUTS(model, 
                      target_accept_prob=0.8,
                      max_tree_depth=6,
                      init_strategy=init_to_sample,
                      dense_mass=True)
    
    # Run MCMC
    mcmc = MCMC(nuts_kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                chain_method='parallel',
                progress_bar=True,
                jit_model_args=True)
    
    # Generate random keys for each chain
    rng_key = random.PRNGKey(0)
    rng_keys = random.split(rng_key, num_chains)
    
    # Run inference
    print("Running inference...")
    mcmc.run(rng_keys, data)
    
    # Plot diagnostics
    samples = mcmc.get_samples()
    az.plot_trace(az.from_numpyro(mcmc))
    plt.tight_layout()
    plt.show()
    
    # Plot predictions
    plot_predictions(samples, data, dose_schedule)
    
    return mcmc, samples

if __name__ == "__main__":
    start_time = time.time()
    
    # Create data dictionary with synthetic data
    data = {
        'times': jnp.array(synthetic_times),
        'bgts': jnp.array(synthetic_bgts)
    }
    
    # Run inference with reduced samples for example
    mcmc, samples = run_inference(data, dose_schedule, num_warmup=50, num_samples=100, num_chains=num_chains)
    
    # Print results
    print("\nParameter Estimates:")
    print("-" * 50)
    mcmc.print_summary()
    
    print("\nMean Parameter Values:")
    print("-" * 50)
    for param in ['alpha_removal', 'k_repair', 'A_beta0', 'response_power', 'sigma']:
        mean_value = jnp.mean(samples[param])
        std_value = jnp.std(samples[param])
        print(f"{param}: {mean_value:.6f} ± {std_value:.6f}")
    
    end_time = time.time()
    runtime = end_time - start_time
    
    print("\nRuntime Information:")
    print("-" * 50)
    print(f"Total runtime: {runtime:.2f} seconds ({runtime/60:.2f} minutes)")


I think the bottleneck is the ODE solver. With the default constant step size controller and dt0=0.01 over roughly 670 days, each likelihood and gradient evaluation costs on the order of 67,000 solver steps, and NUTS makes many such evaluations per sample. You can look into the step size controller and other solver parameters rather than just max_steps.
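
For example, telling an adaptive controller where the dosing discontinuities are via jump_ts might let it work instead of falling back to fixed steps. This is an untested sketch, meant as a drop-in replacement for the diffeqsolve call inside simulate_model; the rtol/atol/dt0/max_steps values are placeholders you would need to tune:

dose_times = jnp.array([d[0] for d in dose_schedule])
# Each zero-order dose switches on at its start time and off D1 days later
jumps = jnp.sort(jnp.concatenate([dose_times, dose_times + D1]))

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(model_rhs),
    diffrax.Dopri5(),
    t0=0,
    t1=t_span[-1],
    dt0=0.1,
    y0=y0,
    saveat=diffrax.SaveAt(ts=t_span),
    args=(F, D1, KA, CL, Vc, Q, Vp, alpha_removal, k_repair),
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8, jump_ts=jumps),
    max_steps=100000,
)

The jump_ts points matter because each dose is only active for D1 ≈ 0.08 days, so any stepping scheme that jumps over that window without evaluating inside it misses the dose entirely.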

It might be better to use SVI instead if mcmc is slow.
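
If you want to try it, SVI can reuse the same numpyro model. A rough, untested sketch (AutoNormal guide, placeholder optimiser settings; model here is the model function defined inside run_inference) would look something like:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro import optim

guide = AutoNormal(model)
svi = SVI(model, guide, optim.Adam(step_size=1e-3), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 5000, data)  # 5000 ELBO optimisation steps

# Approximate posterior samples from the fitted guide
posterior_samples = guide.sample_posterior(
    random.PRNGKey(1), svi_result.params, sample_shape=(1000,))

Each ELBO step still requires an ODE solve and its gradient, so the solver cost matters just as much, but you control the number of optimisation steps directly rather than paying for NUTS tree building.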

Thank you for the response. That could be the case. I know the original implementation also avoided MCMC, using SAEM (stochastic approximation of the EM algorithm) instead. Do you know of any worked examples applying SVI to an ODE system?

Also, is hierarchical modeling possible with SVI? In the future I am interested in having parameters vary by patient, since in my datasets each patient receives a different dosing schedule.
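
To make that concrete, the structure I have in mind is roughly the following (a sketch only, not code I am running: it pretends every patient shares dose_schedule, lets only alpha_removal vary by patient, and uses placeholder prior scales loosely based on the single-patient priors above):

def hierarchical_model(times, bgts):
    # bgts has shape (n_patients, n_times); only alpha_removal varies by patient here
    n_patients = bgts.shape[0]

    # Population-level parameters
    mu_alpha = numpyro.sample('mu_alpha', dist.TruncatedNormal(1e-4, 1e-5, low=0.0))
    sigma_alpha = numpyro.sample('sigma_alpha', dist.HalfNormal(1e-5))
    k_repair = numpyro.sample('k_repair', dist.TruncatedNormal(0.01, 0.001, low=0.0))
    A_beta0 = numpyro.sample('A_beta0', dist.TruncatedNormal(2.0, 0.2, low=0.0))
    response_power = numpyro.sample('response_power', dist.TruncatedNormal(4.0, 0.2, low=0.0001))
    sigma = numpyro.sample('sigma', dist.HalfNormal(2.5))

    patient_plate = numpyro.plate('patients', n_patients)

    # Patient-level parameter drawn around the population mean
    with patient_plate:
        alpha_removal = numpyro.sample('alpha_removal',
            dist.TruncatedNormal(mu_alpha, sigma_alpha, low=0.0))

    # One ODE solve per patient, vectorised over the patient-level parameter
    def solve_one(alpha_i):
        params = jnp.array([alpha_i, k_repair, A_beta0, response_power])
        _, bgts_pred = simulate_model(params, dose_schedule, times)
        return bgts_pred

    bgts_pred = jax.vmap(solve_one)(alpha_removal)  # shape (n_patients, n_times)

    with patient_plate:
        numpyro.sample('obs', dist.Normal(bgts_pred, sigma).to_event(1), obs=bgts)

My understanding is that autoguides handle plates, so the same SVI setup should in principle apply here too, but I would appreciate confirmation.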
