MCMC fitting solution of coupled ODEs super slow

I’m pretty new to PPLs and have been exploring Monte Carlo methods to fit some experimental physical data. The problem is that the code in numpyro is very slow. I’m using a GPU on Colab, and it spends about 5 minutes in background_compile() and _execute_compiled() before the progress bar appears. At that point it predicts ~100 hours to complete sampling with 1000 burn-in (warm-up) steps, 10000 samples and 250 vectorised chains. This seems quite long to me, and I’m wondering whether anyone with more experience thinks this is typical or whether my code is suboptimal. In addition, GPU memory usage almost immediately reaches 13 GB of the available 15 GB once sampling starts.

The code is below and is hopefully well commented. To test everything I’ve simulated data from the model and fixed all but one of the parameters. The time range is 0 to 3200 ns with $\Delta t = 100$ ns. I have also tried solving the ODEs with jax.experimental.ode.odeint but ran into the same problems.

```python
#Import required packages
import jax

#Set all floats to 64 bit (do this before any arrays are created)
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental.ode import odeint  #only needed for the odeint comparison

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import diffrax
import equinox as eqx

#Rate equations for the charge-carrier dynamics
class BTDP(eqx.Module):
    """
    Rate equations for the BTD model, as per p. 28348,
    including trapping, depopulation and background doping.

    The model allows for a non-constant density of traps, with depopulation
    from these traps to the valence band.

    Returns
    -------
    [f0, f1, f2]: array
        Array of the rate equations: dynamics of the electron, trap and hole
        populations.
    """

    kt: float
    kb: float
    kd: float
    NT: float
    p0: float

    def __call__(self, t, y, args):
        B = self.kb * y[0] * (y[2] + self.p0)   #bimolecular recombination
        T = self.kt * y[0] * (self.NT - y[1])   #trapping into empty traps
        D = self.kd * y[1] * (y[2] + self.p0)   #depopulation to valence band
        f0 = -B - T
        f1 = T - D
        f2 = -B - D
        return jnp.stack([f0, f1, f2])

#Solve the ODEs (compiled by JAX when the model is traced)
def solve_TRPL_BTDP(kt, kb, kd, NT, p0, N0):
    """
    Solve the ODEs for the BTD model.

    Parameters
    ----------
    kt: float
        k_T trapping rate constant (cm^3 ns^-1).
    kb: float
        k_B bimolecular rate constant (cm^3 ns^-1).
    kd: float
        k_D depopulation rate constant (cm^3 ns^-1) (trap to valence band).
    NT: float
        Trap density (cm^-3).
    p0: float
        Doping density (cm^-3).
    N0: float
        Initial electron concentration (cm^-3).

    Returns
    -------
    sol: diffrax.Solution
        Solution to the ODEs.
    """
    #Define equations
    btdp = BTDP(kt, kb, kd, NT, p0)
    terms = diffrax.ODETerm(btdp)

    #Start and end times (ns)
    t0 = 0.0
    t1 = 3200.0

    #Initial conditions and initial time step
    y0 = jnp.array([N0, 0.0, N0 + p0])
    dt0 = 0.0002

    #Define solver and times to save at (0, 100, ..., 3200 ns)
    solver = diffrax.Kvaerno5()
    saveat = diffrax.SaveAt(ts=jnp.arange(33, dtype=jnp.float64) * 100)

    #Controller for adaptive time stepping
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)

    #Solve ODEs
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol

#Calculate the TRPL signal
def TRPL_BTDP(kt, kb, kd, NT, p0, y0, N0):
    """
    Calculate the TRPL signal for the BTD model.

    Parameters
    ----------
    kt: float
        k_T trapping rate constant (cm^3 ns^-1).
    kb: float
        k_B bimolecular rate constant (cm^3 ns^-1).
    kd: float
        k_D depopulation rate constant (cm^3 ns^-1) (trap to valence band).
    NT: float
        Trap density (cm^-3).
    p0: float
        Doping density (cm^-3).
    y0: float
        Initial TRPL counts (counts).
    N0: float
        Initial electron concentration (cm^-3).

    Returns
    -------
    sig: array
        TRPL signal (log10 counts).
    """
    #Solve ODEs
    sol = solve_TRPL_BTDP(kt, kb, kd, NT, p0, N0)

    #Calculate TRPL signal: log10(n * p)
    sig = jnp.log10(sol.ys[:, 0]) + jnp.log10(sol.ys[:, 2])

    #Normalise signal and correct for the initial TRPL counts
    sig = sig - sig[0] + jnp.log10(y0)

    return sig

#Standardise the data
def _standardise(x):
    """
    Standardise the data to have a mean of 0 and a standard deviation of 1.

    Parameters
    ----------
    x: numpy.ndarray
        The data to standardise.

    Returns
    -------
    x: numpy.ndarray
        The standardised data.
    """
    return (x - x.mean()) / x.std()

#JIT compiled function to standardise the data
standardise = jax.jit(_standardise)

#Warm up the JIT
standardise(TRPL_BTDP(1e-15, 1e-18, 0, 1e13, 2e13, 6000, 1e15)).block_until_ready()

#Bayesian model
def model(N0, y, y0):
    """
    Bayesian model for the BTDP model.

    Parameters
    ----------
    N0: float
        Initial electron concentration (cm^-3).
    y: array
        log10 of the experimental TRPL signal.
    y0: float
        Initial TRPL counts (counts).
    """
    #Define priors for the rate constants in log space
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(
            low=jnp.array([-16.0]),
            high=jnp.array([-14.0]),
            loc=jnp.array([-15.0]),
            scale=jnp.array([1.0]),
        ),
    )

    #Calculate the TRPL signal and standardise
    #Testing with fixed parameters for all but kt
    signal = TRPL_BTDP(10**theta[0], 5e-18, 0, 1e13, 2e13, y0, N0)
    signal = standardise(signal)

    #Define the likelihood
    numpyro.sample("y", dist.Normal(signal, .02), obs=y)

#Generate some simulated data to test the fitting
#Time points are fixed in the model for now
N0 = 1e16
y = TRPL_BTDP(2.2e-15, 5e-18, 0, 1e13, 2e13, 6000, N0)

#Create a normal distribution object
normal = dist.Normal(0, .02)

#Generate random samples from the normal distribution
noise = normal.sample(PRNGKey(84), (33,))

num_warmup, num_samples = 1000, 10000

#Define the MCMC (250 vectorised chains)
mcmc = MCMC(
    NUTS(model, dense_mass=True),
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=250,
    chain_method="vectorized",
)

#Run the MCMC with the simulated data and noise
mcmc.run(PRNGKey(0), N0=N0, y=standardise(y + noise), y0=10**(y + noise)[0])
```

do i understand correctly that you want to integrate across 3200 / 100ns ~ 10^9 time steps?

Apologies, it’s 0, 100, 200, …, 3200 ns, so 33 points. The data look like the plot below.

[Figure: simulated TRPL data. Blue dots: data with noise added; orange line: noise-free model. x-axis: time (ns); y-axis: time-resolved PL counts.]
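For reference, the save grid used in `solve_TRPL_BTDP` (`jnp.arange(33, dtype=jnp.float64)*100`) produces exactly these 33 output times; the NumPy sketch below just checks that. These are only the times at which `diffrax.SaveAt(ts=...)` records the solution. The number of internal solver steps is set by the adaptive `PIDController` (through `rtol`/`atol`) and the initial step `dt0`, not by the output spacing.

```python
import numpy as np

# Output times used in solve_TRPL_BTDP: 0, 100, ..., 3200 ns.
ts = np.arange(33, dtype=np.float64) * 100.0
print(len(ts), ts[0], ts[-1])  # 33 0.0 3200.0

# Equivalent construction that makes both endpoints explicit.
assert np.allclose(ts, np.linspace(0.0, 3200.0, 33))
```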

  • it’s hard to say what’s going on but i suspect your problem is units (or at least that’s one problem). instead of using your favorite SI units use units that result in all relevant quantities being order unity. this will make it much easier for HMC to explore the model log density.

  • you might also make sure you’re using 64-bit precision.

  • also i don’t suggest using so many chains, at least to begin with. first get things to work with 1 chain. you might naively suppose you can just scale up the number of chains without paying much computational cost but that is not the case. both NUTS and ode integration are/can be adaptive algorithms that take a variable number of steps and so they cannot be parallelized without getting dinged for that variability.

Thank you for your suggestions. I’ve rescaled the units (densities are now measured in units of 1e15 cm^-3), so the rate constants in the fit are now 1e15 larger and the densities 1e15 smaller, e.g.

TRPL_BTDP(1e-15, 1e-18, 0, 1e13, 2e13, 6000, 1e15)

is now

TRPL_BTDP(1, 1e-3, 0, 1e-2, 2e-2, 6000, 1)

(6000 is a multiplication factor applied to the normalised signal; it is not fitted, so it is unchanged here.)
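A minimal NumPy sketch of the rescaling above, assuming densities are measured in units of 1e15 cm^-3 and time stays in ns. `to_scaled_units` and `btdp_rhs` are hypothetical helpers; `btdp_rhs` just reimplements the right-hand side of `BTDP.__call__`. Every term has the form (rate constant) × (density) × (density), so multiplying the rate constants by s and dividing the densities by s divides the right-hand side by s. The dynamics and time axis are unchanged; only the magnitudes the solver sees (and what a fixed `atol` means relative to them) change.

```python
import numpy as np

S = 1e15  # density unit in cm^-3, matching the rescaled call above

def to_scaled_units(kt, kb, kd, NT, p0, N0, s=S):
    """Measure densities in units of s cm^-3 (time stays in ns).

    Rate constants (cm^3 ns^-1) are multiplied by s; densities (cm^-3)
    are divided by s.
    """
    return kt * s, kb * s, kd * s, NT / s, p0 / s, N0 / s

def btdp_rhs(y, kt, kb, kd, NT, p0):
    """Right-hand side of the BTDP rate equations (same form as BTDP.__call__)."""
    n, nt, p = y
    B = kb * n * (p + p0)
    T = kt * n * (NT - nt)
    D = kd * nt * (p + p0)
    return np.array([-B - T, T - D, -B - D])

# Physical values from the original warm-up call.
phys = dict(kt=1e-15, kb=1e-18, kd=0.0, NT=1e13, p0=2e13, N0=1e15)
scaled = to_scaled_units(**phys)
print(np.round(scaled, 6))

# Scaled RHS == physical RHS / S, so the decay itself is unchanged.
y_phys = np.array([phys["N0"], 0.0, phys["N0"] + phys["p0"]])
rhs_phys = btdp_rhs(y_phys, phys["kt"], phys["kb"], phys["kd"],
                    phys["NT"], phys["p0"])
rhs_scaled = btdp_rhs(y_phys / S, *scaled[:5])
assert np.allclose(rhs_scaled, rhs_phys / S)
```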

However, even in these units some parameters are still not order unity (e.g. kb = 1e-3 and NT = 1e-2). Also, shouldn’t the original code already account for this by sampling the values in log space, or am I misunderstanding?

My understanding is that the line `jax.config.update("jax_enable_x64", True)` ensures 64-bit precision is used.
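One way to confirm that the flag took effect is to check the default dtype of a newly created array; a minimal sketch, assuming a recent JAX version. The update should happen at start-up, before any arrays are created. It may also be worth checking whether float64 is really needed, since many GPUs, the T4 included, have far lower float64 than float32 throughput.

```python
import jax

# Enable double precision at program start, before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array(1.0)
ts = jnp.arange(33, dtype=jnp.float64) * 100
print(x.dtype, ts.dtype)  # float64 float64
```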

Regarding your last point, I understand why starting with one chain is a sensible test case, but I’m a bit confused about the scaling. As I understand it, the chains are independent and do not interact. If we run 250 chains in parallel on a Colab T4 GPU with 2560 CUDA cores, with the ODE solver implemented in JAX, should a computation that takes about 1 minute for one chain on a CPU really take more than 100 hours for 250 chains on the GPU, even allowing for the variable number of steps taken by both the solver and NUTS?

Finally, regarding your first point: what additional information would make it easier to say what’s going on here?


Imagine a computation that takes a stochastic amount of time. If there is one thread, the relevant quantity is the mean of that time distribution. If there are many threads that are synced in lockstep, the overall time is controlled by the slowest thread and thus by the tail of the distribution.
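A small NumPy simulation of this point. The per-iteration cost (in the real problem, roughly ODE solver steps times leapfrog steps in a NUTS trajectory) is drawn from a hypothetical lognormal distribution, so only the qualitative conclusion carries over: one chain pays the mean cost, while chains run in lockstep pay the per-iteration maximum. In JAX the lockstep behaviour comes from `jax.vmap` over a `lax.while_loop` (used by both NUTS and adaptive ODE solvers): with a batched predicate the loop keeps iterating until every batch element has finished.

```python
import numpy as np

rng = np.random.default_rng(0)

n_chains = 250
n_iters = 1000

# Hypothetical heavy-tailed cost of each MCMC iteration for each chain.
costs = rng.lognormal(mean=0.0, sigma=1.0, size=(n_iters, n_chains))

# One chain: total work is the sum of its own costs (governed by the mean).
single_chain_time = costs[:, 0].sum()

# Vectorised chains in lockstep: each iteration waits for the slowest chain.
lockstep_time = costs.max(axis=1).sum()

print(f"mean cost per iteration:       {costs.mean():.2f}")
print(f"mean of per-iteration maximum: {costs.max(axis=1).mean():.2f}")
print(f"lockstep / single-chain time:  {lockstep_time / single_chain_time:.1f}")
```

Even with perfectly parallel hardware, the lockstep run here takes roughly ten times as long as a single chain, because its cost is set by the tail of the distribution.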

That makes sense thank you!