Efficiently sampling a large ODE model (compiling issues?)

I think this loop is the source of the slowness:

```python
for i, rid in enumerate(RID_list):
```

Each `odeint` call can be slow to compile. Because a Python loop is unrolled when JAX traces it, the number of operations that have to be compiled to XLA grows linearly with the number of loop iterations. On top of that, JAX's compile time scales non-linearly with the number of operations, which makes things even worse.
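To see the unrolling effect, here is a minimal sketch. It assumes the `odeint` in question is `jax.experimental.ode.odeint` and uses a hypothetical one-parameter decay model `rhs` as a stand-in for the real model, which isn't shown here. It only traces the function, without compiling it, and counts the top-level equations in the resulting jaxpr as a rough proxy for how much work is handed to XLA.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t, k):
    # Hypothetical stand-in model: exponential decay dy/dt = -k * y.
    return -k * y

def solve_loop(ks, y0, ts):
    # Python loop over ks: tracing unrolls it, so every iteration
    # adds its own odeint call to the traced program.
    return jnp.stack([odeint(rhs, y0, ts, k) for k in ks])

def count_eqns(n):
    # Trace (but do not compile) the loop version for n parameter values.
    ks = jnp.linspace(0.5, 2.0, n)
    y0 = jnp.array(1.0)
    ts = jnp.linspace(0.0, 1.0, 5)
    closed = jax.make_jaxpr(solve_loop)(ks, y0, ts)
    return len(closed.jaxpr.eqns)

for n in (1, 4, 16):
    print(n, count_eqns(n))
```

The printed counts grow with `n`, and XLA has to compile all of those equations.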

As a fix, you can try using `jax.vmap` over your `odeint` call, so that a single batched solve is traced and compiled instead of one solve per loop iteration. As a rule of thumb, if compilation takes more than a minute, it's worth reconsidering the approach. Hope it helps! Please let me know if using `vmap` in your code turns out to be tricky.
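A minimal sketch of the `vmap` approach follows. It again assumes `jax.experimental.ode.odeint` and a hypothetical decay model `rhs` whose only per-sample input is a rate `k`. In the real code, whatever varies with `rid` would need to be packed into arrays with a leading batch axis so it can be mapped over. The explicit `rtol`/`atol` values are an illustrative choice, not taken from the original code.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t, k):
    # Hypothetical stand-in model: exponential decay dy/dt = -k * y.
    return -k * y

@jax.jit
def solve_batched(ks, y0, ts):
    # vmap batches one odeint over all parameter values, so the traced
    # program contains a single solve regardless of len(ks).
    solve_one = lambda k: odeint(rhs, y0, ts, k, rtol=1e-6, atol=1e-6)
    return jax.vmap(solve_one)(ks)

ts = jnp.linspace(0.0, 1.0, 5)
y0 = jnp.array(1.0)
ks = jnp.array([0.5, 1.0, 2.0])  # one entry per sample (hypothetical)

sol = solve_batched(ks, y0, ts)  # shape (len(ks), len(ts))
exact = jnp.exp(-ks[:, None] * ts[None, :])  # analytic solution for comparison
print(sol.shape, float(jnp.max(jnp.abs(sol - exact))))
```

Because the batch dimension is handled inside one traced solve, compile time no longer grows with the number of samples, unlike the unrolled Python loop.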
