I’m pretty new to PPLs and have been exploring using Monte Carlo methods to fit some experimental physical data. The problem is that my NumPyro code is super slow. I’m using a GPU on Colab, and it seems to spend around 5 minutes in background_compile() and _execute_compiled() before showing the progress bar. At that point it predicts it will take ~100 hours to complete the sampling for 1000 burn-in steps, 10000 sampling steps and 250 vectorised chains. This seems quite long to me, and I’m wondering whether anyone with more experience believes this is typical or whether my code is not optimal. In addition, GPU memory usage reaches 13 GB of the available 15 GB almost instantly once the run starts.
The code is below and is hopefully well commented. To test everything, I’ve simulated data from the model and fixed all but one of the variables. The time scale is (0, 3200) ns with $\Delta t = 100$ ns. I have also tried solving the ODEs with jax.experimental.ode.odeint, but I run into the same problems.
Thanks in advance for any suggestions/insights.
# Import required packages
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.experimental.ode import odeint

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import diffrax
import equinox as eqx

# Set all floats to 64 bit
jax.config.update("jax_enable_x64", True)


# Rate equations for the charge carrier dynamics
class BTDP(eqx.Module):
    """
    Rate equations for the BTD model, as per doi.org/10.1039/d0cp04950f
    p. 28348 including trapping, depopulation and background doping.
    Model allows for a non-constant density of traps with depopulation from
    these traps to the valence band.

    Attributes
    ----------
    kt: float
        k_T trapping rate constant (cm^3 ns^-1).
    kb: float
        k_B bimolecular rate constant (cm^3 ns^-1).
    kd: float
        k_D depopulation rate constant (cm^3 ns^-1) (trap to valence band).
    NT: float
        Trap density (cm^-3).
    p0: float
        Doping density (cm^-3).
    """

    kt: float
    kb: float
    kd: float
    NT: float
    p0: float

    def __call__(self, t, y, args):
        """
        Evaluate the right-hand side of the rate equations.

        Parameters
        ----------
        t: float
            Current time (unused; the system is autonomous).
        y: array
            State vector [electron, trap, hole] concentrations.
        args: Any
            Unused; present to match the signature diffrax.ODETerm expects.

        Returns
        -------
        [f0, f1, f2]: array
            Array of the rate equations. Dynamics for electron, trap and hole
            concentrations.
        """
        # NOTE(review): the pasted source had lost its subscripts here (every
        # term used the whole vector ``y``, so the stacked result was (3, 3)
        # rather than (3,)). The indexing below follows the docstring's
        # electron/trap/hole ordering and the structure of the cited BTD rate
        # equations -- confirm against the original notebook.
        n_e, n_t, n_h = y[0], y[1], y[2]

        B = self.kb * n_e * (n_h + self.p0)  # bimolecular recombination
        T = self.kt * n_e * (self.NT - n_t)  # trapping into empty traps
        D = self.kd * n_t * (n_h + self.p0)  # trap depopulation

        f0 = -B - T
        f1 = T - D
        f2 = -B - D
        return jnp.stack([f0, f1, f2])


# JIT compiled function to solve the ODE
@jax.jit
def solve_TRPL_BTDP(kt, kb, kd, NT, p0, N0):
    """
    Solve the ODEs for the BTD model.

    Parameters
    ----------
    kt: float
        k_T trapping rate constant (cm^3 ns^-1).
    kb: float
        k_B bimolecular rate constant (cm^3 ns^-1).
    kd: float
        k_D depopulation rate constant (cm^3 ns^-1) (trap to valence band).
    NT: float
        Trap density (cm^-3).
    p0: float
        Doping density (cm^-3).
    N0: float
        Initial electron concentration (cm^-3).

    Returns
    -------
    sol: diffrax.Solution
        Solution to the ODEs; ``sol.ys`` has one row per saved time point
        (33 points, 0 to 3200 ns in steps of 100 ns) and one column per
        species (electron, trap, hole).
    """
    # Define equations
    btdp = BTDP(kt, kb, kd, NT, p0)
    terms = diffrax.ODETerm(btdp)

    # Start and end times
    t0 = 0.0
    t1 = 3200.0

    # Initial conditions and initial time step
    y0 = jnp.array([N0, 0.0, N0 + p0])
    dt0 = 0.0002

    # Define solver and times to save at
    solver = diffrax.Kvaerno5()
    saveat = diffrax.SaveAt(ts=jnp.arange(33, dtype=jnp.float64) * 100)

    # Controller for adaptive time stepping
    # NOTE(review): an implicit solver at rtol=atol=1e-8, differentiated by
    # NUTS for every leapfrog step of every chain, is likely a large share of
    # the run time; looser tolerances may be worth profiling.
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)

    # Solve ODEs
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol


# JIT compiled function to calculate the TRPL signal
@jax.jit
def TRPL_BTDP(kt, kb, kd, NT, p0, y0, N0):
    """
    Calculate the TRPL signal for the BTD model.

    Parameters
    ----------
    kt: float
        k_T trapping rate constant (cm^3 ns^-1).
    kb: float
        k_B bimolecular rate constant (cm^3 ns^-1).
    kd: float
        k_D depopulation rate constant (cm^3 ns^-1) (trap to valence band).
    NT: float
        Trap density (cm^-3).
    p0: float
        Doping density (cm^-3).
    y0: float
        Initial TRPL counts (counts).
    N0: float
        Initial electron concentration (cm^-3).

    Returns
    -------
    sig: array
        log10 of the TRPL signal at each saved time point, shifted so that
        the first point equals log10(y0).
    """
    # Solve ODEs
    sol = solve_TRPL_BTDP(kt, kb, kd, NT, p0, N0)

    # Calculate TRPL signal: log10 of (electron * hole) concentration
    sig = jnp.log10(sol.ys[:, 0]) + jnp.log10(sol.ys[:, 2])

    # Normalise signal to its first point and correct for the initial TRPL
    # counts. The pasted source read ``sig - sig``, which cancels the model
    # signal entirely; subtracting the first sample matches the comment's
    # stated intent.
    sig = sig - sig[0] + jnp.log10(y0)
    return sig


# Standardise the data
def _standardise(x):
    """
    Standardise the data to have a mean of 0 and a standard deviation of 1.

    Parameters
    ----------
    x: array
        The data to standardise.

    Returns
    -------
    x: array
        The standardised data.
    """
    return (x - x.mean()) / x.std()


# JIT compiled function to standardise the data
standardise = jax.jit(_standardise)

# Warm up the JIT
standardise(
    TRPL_BTDP(1e-15, 1e-18, 0, 1e13, 2e13, 6000, 1e15)
).block_until_ready()


# Bayesian model
def model(N0, y, y0):
    """
    Bayesian model for the BTDP model.

    Parameters
    ----------
    N0: float
        Initial electron concentration (cm^-3).
    y: array
        Standardised log10 of the experimental TRPL signal.
    y0: float
        Initial TRPL counts (counts).
    """
    # Define priors for the rate constants in log space
    # NOTE(review): the prior scale was missing from the pasted source
    # (``jnp.array()`` raises TypeError). 1.0 is a placeholder that keeps the
    # script runnable -- replace it with the intended value.
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(
            low=jnp.array([-16.0]),
            high=jnp.array([-14.0]),
            loc=jnp.array([-15.0]),
            scale=jnp.array([1.0]),
        ),
    )

    # Calculate the TRPL signal and standardise
    # Testing with fixed parameters for all but kt.
    # theta has shape (1,); take the scalar so the ODE state stays shape (3,).
    signal = TRPL_BTDP(10 ** theta[0], 5e-18, 0, 1e13, 2e13, y0, N0)
    signal = standardise(signal)

    # Define the likelihood
    numpyro.sample("y", dist.Normal(signal, 0.02), obs=y)


# Generate some simulated data to test the fitting
# Time points fixed in model for now
N0 = 1e16
y = TRPL_BTDP(2.2e-15, 5e-18, 0, 1e13, 2e13, 6000, N0)

# Create a normal distribution object
normal = dist.Normal(0, 0.02)

# Generate random samples from the normal distribution (one per time point)
noise = normal.sample(PRNGKey(84), y.shape)

num_warmup, num_samples = 1000, 10000

# Define the MCMC
# NOTE(review): 250 vectorised chains x 11000 iterations, each differentiating
# through the ODE solve, is a very large workload; a handful of chains is
# usually enough for a one-parameter posterior. The near-instant GPU memory
# use is probably JAX preallocating most of the device memory by default
# rather than the model itself -- confirm against the JAX GPU memory docs.
mcmc = MCMC(
    NUTS(model, dense_mass=True),
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=250,
    chain_method="vectorized",
    progress_bar=True,
)

# Run the MCMC with the simulated data and noise.
# y0 is documented as a scalar (initial TRPL counts), so it is taken from the
# first noisy sample; the pasted source passed the whole 33-element array.
y_noisy = y + noise
mcmc.run(
    PRNGKey(78),
    N0=N0,
    y=standardise(y_noisy),
    y0=10 ** y_noisy[0],
)
mcmc.print_summary()