Hi everyone, I’m new to NumPyro and want to efficiently sample a large ODE model. Ultimately, I would like to scale it up to 30 variables and data from up to 1000 individuals, each with around 10 timepoints and many missing values.
Below you can find my toy model code. My main question is that the code runs efficiently when I just use 10 individuals in the data (2 minutes), but when I use 100 individuals, for instance, the sampler already takes 32 minutes. Most of that time is spent ‘compiling’, after which the progress bar suggests it efficiently samples.
I get this error message at some point:
“Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module jit__body_fn.274620”
Also, when I use real data, the sampler takes more than 12 hours to compile, at which point I just kill the run.
I’m running the code from JupyterLab. Any ideas as to why the model takes so long to compile/run? I tried turning off the progress bar (as was suggested elsewhere), but that didn’t make any difference. Once the model actually appears to start sampling, it’s lightning fast.
See my reproducible code below:
import time

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import numpy as n
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from jax import random, ops
from jax.experimental.ode import odeint
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.initialization import init_to_median, init_to_sample
from tqdm import tqdm
# Select the "arviz-darkgrid" plotting style.
# NOTE(review): `az` is not imported anywhere in the visible code; presumably
# `import arviz as az` is missing. Confirm and add it, otherwise this line
# raises NameError.
az.style.use("arviz-darkgrid")
def dz_dt(z, t, A):
    """Evaluate the right-hand side of the linear ODE system dz/dt = A z.

    This is the derivative function handed to ``odeint``; it does not solve
    the system itself.

    Args:
        z: Current state vector.
        t: Current time. Unused because the system is autonomous, but it is
            part of the ``func(y, t, *args)`` signature that ``odeint`` calls.
        A: Coefficient matrix applied to ``z``.

    Returns:
        The time derivative of ``z``, i.e. ``A @ z``.
    """
    return jnp.matmul(A, z)
def _sample_theta(prior, n_stocks):
    """Sample the flat vector of n_stocks * n_stocks interaction strengths."""
    if prior == 'truncnorm':
        size = n_stocks * n_stocks
        return numpyro.sample(
            "theta",
            dist.TruncatedNormal(
                low=jnp.full((size,), -0.05),
                loc=jnp.full((size,), 0.05),
                scale=jnp.full((size,), 0.05),
            ),
        )
    if prior == 'invgamma':
        # Previously this branch was a bare `pass`, which left `theta` unbound.
        raise NotImplementedError("prior='invgamma' is not implemented yet")
    raise ValueError(f"Unknown prior: {prior!r}")


def _fill_missing(data, len_na, nan_policy):
    """Return (y, observed_mask) with NaNs in `data` handled per `nan_policy`.

    `observed_mask` is None unless the likelihood has to be masked.
    """
    if nan_policy is None:
        return data, None
    if nan_policy == 'impute_bl':  # TODO
        # Previously a bare `pass`, which left `y` unbound.
        raise NotImplementedError("nan_policy='impute_bl' is not implemented yet")
    if nan_policy not in ('impute', 'mask'):
        raise ValueError(f"Unknown nan_policy: {nan_policy!r}")

    # Plain NumPy on purpose: the boolean index must be a concrete array.
    missing = np.isnan(data)
    if nan_policy == 'impute':
        # mask(False) removes the prior term of the imputed values from the
        # joint density; they are only scored through the likelihood.
        y_imputed = numpyro.sample(
            "y_imputed",
            dist.Normal(0, 1).expand([int(len_na)]).mask(False),
        )
        # `.at[...].set(...)` replaces the removed `jax.ops.index_update`.
        return jnp.asarray(data).at[missing].set(y_imputed), None
    # 'mask': put a placeholder where data is missing and exclude those
    # entries from the likelihood.
    return jnp.asarray(data).at[missing].set(-999), ~missing


def _solve_trajectories(df, stocks, RID_list, A):
    """Integrate dz/dt = A z per individual and stack the non-baseline steps."""
    z0_list = []
    t_list = []
    for rid in RID_list:
        df_i = df.loc[df.RID == rid, :]  # rows of a single individual
        z0_list.append(np.asarray(df_i[stocks].values[0], dtype=float))  # initial state
        t_list.append(np.asarray(list(df_i['Time']), dtype=float))  # time steps to return

    if len({len(t_i) for t_i in t_list}) == 1:
        # Same number of time points for everyone: solve all individuals in
        # one batched call. A Python loop would trace one separate solver per
        # individual, so the traced program (and its compile time) grows with
        # the number of individuals.
        z0 = jnp.array(np.stack(z0_list))
        ts = jnp.array(np.stack(t_list))
        z = jax.vmap(lambda z0_i, t_i: odeint(dz_dt, z0_i, t_i, A))(z0, ts)
        # Drop the baseline step, then flatten (individual, time) into rows in
        # the same order the per-individual concatenation produced.
        return z[:, 1:, :].reshape(-1, z.shape[-1])

    # Ragged time grids cannot be stacked; fall back to one solve each.
    return jnp.concatenate([
        odeint(dz_dt, jnp.array(z0_i), jnp.array(t_i), A)[1:, :]
        for z0_i, t_i in zip(z0_list, t_list)
    ])


def model(df, data, len_na, stocks, dir_mat, RID_list, prior, nan_policy=None):
    """NumPyro model of a linear ODE system observed with Gaussian noise.

    Args:
        df: DataFrame with a ``RID`` column, a ``Time`` column and one column
            per entry of ``stocks``. The first row of each individual provides
            the ODE initial state.
        data: Observed non-baseline values, one column per stock. Rows must be
            ordered like the stacked trajectories: individuals in ``RID_list``
            order, each with its non-baseline time steps in row order.
        len_na: Number of NaN entries in ``data`` (used by ``'impute'``).
        stocks: Names of the state variables (columns of ``df``).
        dir_mat: Matrix multiplied element-wise into the reshaped ``theta``.
        RID_list: Individual identifiers to solve for, in order.
        prior: Prior family for ``theta``; only ``'truncnorm'`` is implemented.
        nan_policy: ``None`` (use ``data`` as is), ``'impute'`` (sample the
            missing entries) or ``'mask'`` (exclude them from the likelihood).

    Raises:
        NotImplementedError: For ``prior='invgamma'`` or
            ``nan_policy='impute_bl'``.
        ValueError: For any other unknown ``prior`` or ``nan_policy``.
    """
    n_stocks = len(stocks)

    ## Set prior parameters
    theta = _sample_theta(prior, n_stocks)
    A = theta.reshape([n_stocks, n_stocks]) * dir_mat

    y, observed = _fill_missing(data, len_na, nan_policy)
    z = _solve_trajectories(df, stocks, RID_list, A)

    sigma = numpyro.sample("sigma", dist.LogNormal(1, 1).expand([n_stocks]))
    if observed is not None:
        with numpyro.handlers.mask(mask=observed):
            numpyro.sample("y_pred", dist.Normal(z, sigma), obs=y)
    else:
        numpyro.sample("y_pred", dist.Normal(z, sigma), obs=y)
## Generate data
N = 100  # Number of individuals
t_end = 24
# Observation times: baseline (0), 3, then every 6 time units up to t_end.
ts = jnp.array(
    np.concatenate(([0.0, 3.0], np.linspace(0, t_end, int(t_end / 6) + 1)[1:]))
)
stocks = ["ADAS13", "FDG", "WMH", "AV45"]
# Correlated random initial states, one row per individual.
z_init = jnp.array(
    np.random.multivariate_normal(
        [0, 0, 0, 0],
        np.array([
            [1, -0.1, 0.1, 0.1],
            [-0.1, 1, -0.1, -0.1],
            [0.1, -0.1, 1, 0.1],
            [0.1, -0.1, 0.1, 1],
        ]),
        N,
    )
)
dir_mat = jnp.array([[1, -1, 1, 1],  # Direction matrix
                     [-1, 1, -1, -1],
                     [1, -1, 1, 1],
                     [1, -1, 1, 1]])
# Coefficient matrix used to simulate the data.
A = jnp.array([[0.02, -0.02, 0.03, 0.05],
               [-0.02, 0, -0.05, -0.03],
               [0.03, -0.02, 0, 0.04],
               [0, 0, 0.03, 0.03]])
# One simulated trajectory per individual, each of shape (len(ts), len(stocks)).
z_list = [odeint(dz_dt, z_init[i], ts, A) for i in range(N)]
# Add Gaussian observation noise (sd 0.1) to the stacked trajectories.
y_real = np.random.normal(np.asarray(jnp.concatenate(z_list)), 0.1)

## Dataframe
# Plain NumPy columns: building "Time" from a list of JAX scalars would not
# give a clean float column.
df_opt = pd.DataFrame({
    "Time": np.tile(np.asarray(ts, dtype=float), N),
    "RID": np.repeat(np.arange(1, N + 1), len(ts)),
    **{stock: y_real[:, j] for j, stock in enumerate(stocks)},
})

# Add NaNs
nan_num = 5
if nan_num != 0:
    df_opt_without_na = df_opt.copy()
    # Only non-baseline rows may lose a value: the baseline row is the ODE
    # initial state, and a NaN there turns the whole trajectory into NaN.
    candidate_rows = df_opt.index[df_opt["Time"] != 0]
    for _ in range(nan_num):
        row = np.random.choice(candidate_rows)
        # randint's upper bound is exclusive, so pass len(stocks) itself;
        # `len(stocks) - 1` could never select the last stock.
        col = stocks[np.random.randint(len(stocks))]
        df_opt.loc[row, col] = np.nan
    print("Number of NaNs added: ", df_opt.isna().sum().sum())
else:
    print("No NaNs added")
## Sampler configuration
num_samples = 300  # 500
warm_up_samples = 200
num_chains = 1
# NOTE: this only takes effect if it runs before JAX initialises its backend.
# With num_chains > 1, move it to the very top of the script, before any JAX
# operation is executed.
numpyro.set_host_device_count(num_chains)
nan_policy = 'impute'
prior = 'truncnorm'

mcmc = MCMC(
    NUTS(
        model,
        dense_mass=True,
        init_strategy=init_to_median(),
    ),
    num_warmup=warm_up_samples,
    num_samples=num_samples,
    num_chains=num_chains,
    progress_bar=True,
    chain_method="parallel",
)

# Observed values exclude the baseline rows, which only provide the ODE
# initial state of each individual.
observed = df_opt.loc[df_opt.Time != 0, stocks].to_numpy(dtype=float)
# Keep individuals in order of first appearance so the stacked model
# trajectories line up row-for-row with `observed`; a set has no defined order.
rid_list = df_opt["RID"].unique().tolist()

start = time.time()
mcmc.run(
    PRNGKey(1),
    df=df_opt,
    data=jnp.array(observed),
    # Count NaNs in the array that is actually imputed. Counting over the
    # whole frame also includes baseline rows and can mismatch `data`.
    len_na=int(np.isnan(observed).sum()),
    stocks=stocks,
    dir_mat=dir_mat,
    RID_list=rid_list,
    prior=prior,
    nan_policy=nan_policy,
)
end = time.time()
print("The sampling took ", (end-start)/60, " minutes.")