Efficiently sampling a large ODE model (compiling issues?)

Hi everyone, I'm new to NumPyro and want to efficiently sample a large ODE model. Ultimately, I would like to scale it up to 30 variables and data from up to 1000 individuals, with around 10 timepoints each and many missing values.

Below you can find my toy model code. My main issue is that the code runs quickly when I use just 10 individuals in the data (2 minutes), but with 100 individuals the sampler already takes 32 minutes. Most of that time is spent 'compiling', after which the progress bar suggests it samples efficiently.

I get this error message at some point:
“Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module jit__body_fn.274620”

Also, when I use real data, the sampler spends >12 hours compiling before I just exit the code.

I'm running the code from JupyterLab. Any ideas as to why the model takes so long to compile/run? I tried turning off the progress bar (as suggested elsewhere) but that didn't do anything. Once the model actually appears to start sampling, it's lightning fast.

See my reproducible code below:

import matplotlib
import matplotlib.pyplot as plt
from jax.experimental.ode import odeint
import jax.numpy as jnp
import jax
from jax import random, ops
from jax.random import PRNGKey
import time
from tqdm import tqdm
import numpyro
from numpyro.infer.initialization import init_to_median, init_to_sample
import numpyro.distributions as dist
import pandas as pd 
import numpy as np
import arviz as az
from numpyro.infer import MCMC, NUTS, Predictive

az.style.use("arviz-darkgrid")

def dz_dt(z, t, A):
    """Right-hand side of the linear ODE system dz/dt = A @ z."""
    return jnp.matmul(A, z)

def model(df, data, len_na, stocks, dir_mat, RID_list, prior, nan_policy=None):
    """NumPyro model: per-individual linear ODE dynamics observed with Gaussian noise."""

    ## Set prior parameters
    
    if prior == 'truncnorm':
        theta = numpyro.sample(
            "theta",        
            dist.TruncatedNormal(
                low=jnp.array([-0.05] * len(stocks)*len(stocks)),
                loc=jnp.array([0.05] * len(stocks)*len(stocks)),
                scale=jnp.array([0.05] * len(stocks)*len(stocks)), 
            ),)
        
    elif prior == 'invgamma':
        pass  # TODO: add an inverse-gamma prior here

    A = theta.reshape([len(stocks), len(stocks)]) * dir_mat
    #print(A) 
    
    if nan_policy=='impute':
        #print("Number of NaN: ", len_nan)
        y_imputed = numpyro.sample("y_imputed", dist.Normal(0, 1).expand([len_na]).mask(False))
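        # NOTE: jax.ops.index_update has been removed from newer JAX releases;
        # there, data.at[np.isnan(data)].set(y_imputed) should be the equivalent
        # (likewise for the index_update call in the 'mask' branch below)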
        y = ops.index_update(data, np.isnan(data), y_imputed)
        
    elif nan_policy == 'impute_bl':   #TODO
        pass
    
    elif nan_policy=='mask':
        nan_idx = ~np.isnan(data)        
        y = jax.ops.index_update(data, np.isnan(data), -999)
        
    elif nan_policy is None:
        y = data
        
    z_list = []
    for i, rid in enumerate(RID_list):
        df_i = df.loc[df.RID == rid, :]  # get dataframe of single individual
        z_i = odeint(dz_dt, 
                     jnp.array(df_i[stocks].values[0]), # Initial population
                     jnp.array(list(df_i['Time'])), # Time-steps to be returned, 
                     A,  #rtol=1e-6, atol=1e-5, mxstep=1000
                     )
        z_list += [z_i[1:,:]]  # Only store non-baseline time-steps
    z = jnp.concatenate(z_list) 

    sigma = numpyro.sample("sigma", dist.LogNormal(1, 1).expand([len(stocks)]))

    if nan_policy=='mask':
        with numpyro.handlers.mask(mask=nan_idx): 
            numpyro.sample("y_pred", dist.Normal(z, sigma), obs=y)
    else:
        numpyro.sample("y_pred", dist.Normal(z, sigma), obs=y)

## Generate data
N = 100  # Number of individuals
t_end = 24
ts = jnp.array(np.array([0.0, 3.0] + list(np.linspace(0, t_end, int(t_end/6) + 1)[1:])))
stocks = ["ADAS13", "FDG", "WMH", "AV45"]

z_init = jnp.array(np.random.multivariate_normal([0, 0, 0, 0], np.array([[1,-0.1,0.1,0.1],[-0.1,1,-0.1,-0.1],[0.1,-0.1,1,0.1], [0.1,-0.1,0.1,1]]), N))

dir_mat = jnp.array([[1, -1, 1, 1],    # Direction matrix
                    [-1, 1, -1, -1],
                    [1, -1, 1, 1],
                    [1, -1, 1, 1]])

A = jnp.array([[0.02, -0.02, 0.03, 0.05], 
               [-0.02, 0, -0.05, -0.03],
               [0.03, -0.02, 0, 0.04],
               [0, 0, 0.03, 0.03]])

z_list = []
for i in range(N):
    z_list += [odeint(dz_dt, z_init[i], ts, A)]

y_real = np.random.normal(jnp.concatenate(z_list), 0.1) 

## Dataframe
df_opt = pd.DataFrame({**{"Time" : list(ts) * N, 
                       "RID" : list(np.array([[i+1]*len(ts) for i in range(N)]).flatten()),
                       **dict(zip(stocks, [y_real[:,i] for i in range(len(stocks))]))}} 
                     )

# Add NaNs
nan_num = 5

if nan_num != 0:
    df_opt_without_na = df_opt.copy()
    for k in range(nan_num):
        # avoid baseline (Time == 0) rows: they supply the ODE initial values
        row = np.random.choice(df_opt.index[df_opt.Time != 0])
        df_opt.loc[row, stocks[np.random.randint(0, len(stocks))]] = np.nan
    print("Number of NaNs added: ", df_opt.isna().sum().sum())
else:
    print("No NaNs added")

num_samples = 300 #500
warm_up_samples = 200
num_chains = 1
numpyro.set_host_device_count(num_chains)
nan_policy = 'impute' 
prior = 'truncnorm' 

mcmc = MCMC(
        NUTS(model, 
             dense_mass=True,  
             init_strategy=init_to_median()),
        num_warmup=warm_up_samples, 
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True, 
        chain_method="parallel",
        )

start = time.time()
mcmc.run(PRNGKey(1), 
         df=df_opt, 
         data=jnp.array(df_opt.loc[df_opt.Time != 0, stocks]), 
         len_na=df_opt[stocks].isna().sum().sum(),
         stocks=stocks, dir_mat=dir_mat,
         RID_list=sorted(df_opt.RID.unique()),  # sorted to match the row order in df_opt
         prior=prior,
         nan_policy=nan_policy)
end = time.time()

print("The sampling took ", (end-start)/60, " minutes.")

Any suggestions? I just tried N=500 using the above code and it hadn't compiled within 12 hours.
I could imagine the sampling itself taking that long, since it has to solve 500 ODEs per sample, but the progress bar and error message suggest the sampler hasn't even finished compiling?

The error message I got again: “Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module jit__body_fn.274620”

I think this is the source of the slowness:

for i, rid in enumerate(RID_list):

Each odeint can be slow to compile, and the number of numerical operators (to be compiled to XLA) grows linearly with the length of your Python loop. In JAX, compile time scales non-linearly with the number of operators, which makes things even worse.

For a solution, you can try to use jax.vmap with your odeint to batch the solves; see the sketch below. When compile time exceeds a minute, a rule of thumb is to reconsider the approach. Hope it helps! Please let me know if it is tricky to use vmap in your code.
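A rough sketch of the idea (y0_batch here is just a placeholder for your N stacked initial conditions; I assume every individual shares the same time grid ts):

# One batched, compiled solve instead of N separately traced odeint calls.
solve_one = lambda y0: odeint(dz_dt, y0, ts, A)  # -> (len(ts), num_stocks)
z = jax.vmap(solve_one)(y0_batch)                # -> (N, len(ts), num_stocks)
z = z[:, 1:, :]  # drop the baseline time step, as your loop does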


Thanks so much for your input! Extremely helpful. Using vmap indeed speeds up the repeated ODE solving significantly for large N. Interestingly, it seems slightly slower for small N, but the runtime barely grows as N increases.

I’m now using the following code, is that also what you had in mind?
jit(vmap(lambda y0: odeint(dz_dt, y0, ts[1:], A)))(data_bl)

I'll be sampling from the 4D ODE model with N=500 data overnight and will report on how fast it runs. Jit may also help. Thanks again!

One more question (sorry, I could also open a new topic, but I've already shared the code here).

When I run Predictive(model, mcmc.get_samples())(...)['y_pred'], the simulated values are (strangely) identical across all posterior samples. This is especially strange because the sampler does identify the correct parameters, so inference itself performs well.

Any ideas what could be the matter?

It seems strange to me too. Could you share the code for Predictive?

jit(vmap(lambda y0: odeint(dz_dt, y0, ts[1:], A)))(data_bl)

I think jit is not needed and might slow your program down a bit. In addition, looking at your code, it seems that you also need to vmap over df_i[stocks].values[0] and list(df_i['Time']). I guess the former is data_bl and the latter is the constant ts[1:]?
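Concretely, something like this sketch (assuming data_bl stacks one baseline row per individual):

# vmap over the batch of initial conditions; ts[1:] is shared by all individuals.
z = vmap(lambda y0: odeint(dz_dt, y0, ts[1:], A))(data_bl)  # (N, len(ts)-1, num_stocks)
z = z.reshape(-1, len(stocks))  # flatten to line up with the stacked observations

Note that odeint treats y0 as the state at the first requested time point, so with ts[1:] the baseline values are anchored at t=ts[1]; if the trajectories should start at t=0, pass the full ts and drop the first returned row per individual instead.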

Thanks!

the df_i terms are indeed data_bl and ts[1:], which I now assign outside of the model function.

I’ll try running the code without jit and see if it speeds up the process further. With jit N=500 took less than an hour!

Regarding Predictive, the code is now as follows. (Note that I changed the code I shared earlier slightly: I now pass data_bl instead of the dataframe (df_opt), along with some additional settings, so that fewer objects are created inside the model.)

posterior_samples = mcmc.get_samples()
posterior_predictive = Predictive(model, posterior_samples)(PRNGKey(2), 
                                  data_bl=jnp.array(df_opt.loc[df_opt.Time == 0, stocks]), 
                                 data=jnp.array(df_opt.loc[df_opt.Time != 0, stocks]), 
                                 len_na=df_opt[stocks].isna().sum().sum(),
                                 num_params=sum(sum(abs(dir_mat))), 
                                 stocks=stocks, dir_mat=dir_mat,
                                 RID_list=list(set(df_opt.RID)),
                                 prior=prior,
                                 nan_policy=nan_policy)

and then posterior predictive gives:

{'y_pred': DeviceArray([[[-2.7531822 ,  1.7178011 ,  0.43667883, -1.100166  ],
               [-1.6983924 , -0.19100092, -1.7278224 , -1.1027367 ],
               [-2.4942384 ,  1.4251862 , -0.17151755, -0.7938792 ],
               ...,
               [ 0.8028703 , -0.59440917,  0.62220967, -0.3956238 ],
               [ 0.35060555, -0.89827913,  1.907384  ,  0.24282666],
               [ 5.2929897 , -2.0510108 ,  3.871882  , -0.4974407 ]],
 
              [[-2.7531822 ,  1.7178011 ,  0.43667883, -1.100166  ],
               [-1.6983924 , -0.19100092, -1.7278224 , -1.1027367 ],
               [-2.4942384 ,  1.4251862 , -0.17151755, -0.7938792 ],
               ...,
               [ 0.8028703 , -0.59440917,  0.62220967, -0.3956238 ],
               [ 0.35060555, -0.89827913,  1.907384  ,  0.24282666],
               [ 5.2929897 , -2.0510108 ,  3.871882  , -0.4974407 ]],
 
              [[-2.7531822 ,  1.7178011 ,  0.43667883, -1.100166  ],
               [-1.6983924 , -0.19100092, -1.7278224 , -1.1027367 ],
               [-2.4942384 ,  1.4251862 , -0.17151755, -0.7938792 ],
               ...,
               [ 0.8028703 , -0.59440917,  0.62220967, -0.3956238 ],
               [ 0.35060555, -0.89827913,  1.907384  ,  0.24282666],
               [ 5.2929897 , -2.0510108 ,  3.871882  , -0.4974407 ]],
 
              ...,

where the repeated samples are simply equal to the (non-baseline) data (y) that I pass via obs=y to numpyro.sample in the model function. I must be making a silly mistake.

Ah, I see. You need to set y to None for prediction. All you need is to add a keyword argument y=None to your model signature, pass y during the MCMC run, and leave it unset when using Predictive; see the Bayesian Regression Using NumPyro tutorial. Your model is just a function: you can control its behavior under different settings by adding new args/kwargs to its signature and changing them when needed, the way you added nan_policy.
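A minimal sketch of the pattern with a toy model (obs_array is a placeholder for your observed data; imports as in the code above):

def toy_model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    numpyro.sample("y_pred", dist.Normal(mu, sigma), obs=y)

mcmc = MCMC(NUTS(toy_model), num_warmup=200, num_samples=300)
mcmc.run(PRNGKey(1), y=obs_array)  # inference: condition on the data
pred = Predictive(toy_model, mcmc.get_samples())(PRNGKey(2))  # y=None -> "y_pred" is sampled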


Amazing, I get it now and it works. Thanks so much!

Hello @yunus, hope you are doing well.

I have a dataset similar to yours, and I wonder how you handled the odeint fitting with different subsets of the data.
Did you handle it with jit(vmap(lambda y0: odeint(dz_dt, y0, ts[1:], A)))(data_bl) in the end?
How did you feed the data of different RIDs into odeint?