I think this loop is the source of the slowness:
for i, rid in enumerate(RID_list):
Each odeint call can be slow to compile, and the number of numerical operations that have to be compiled to XLA scales linearly with the length of your Python loop. In JAX, compile time scales non-linearly with the number of operations, which makes things even worse.
As a solution, you can try using jax.vmap with your odeint call to make it faster. As a rule of thumb, when compile time exceeds 1 minute, it is worth reconsidering the approach. Hope it helps! Please let me know if it is tricky to use vmap in your code.
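In case it helps, here is a minimal sketch of the vmap idea. I don't know your actual ODE, so everything here is a stand-in: decay, ks, y0, ts, and solve_one are hypothetical names, with ks playing the role of whatever varies per entry of RID_list. It assumes you are using odeint from jax.experimental.ode and that every solve shares the same time grid.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Hypothetical right-hand side (not from your code): dy/dt = -k * y.
def decay(y, t, k):
    return -k * y

ts = jnp.linspace(0.0, 1.0, 11)   # time grid shared by all solves
y0 = jnp.array(1.0)               # initial state shared by all solves
ks = jnp.array([0.5, 1.0, 2.0])   # one parameter per item (stand-in for RID_list)

def solve_one(k):
    # A single odeint call for a single parameter value.
    return odeint(decay, y0, ts, k)

# vmap batches solve_one over the leading axis of ks, so only one odeint
# is traced and compiled instead of one per Python loop iteration.
solve_all = jax.jit(jax.vmap(solve_one))
ys = solve_all(ks)                # one trajectory per entry of ks
```

If your per-item inputs are initial states rather than parameters, the same pattern should apply: make that input the argument of solve_one and vmap over it.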