Breaking up the integration into multiple steps slows down the computation

I need to do inference on a dynamical system, but I need to split up the integration at every time step. To integrate, I am using odeint (from jax.experimental.ode import odeint).
Take, for example, the predator-prey example available in the documentation.

If instead of doing

z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)

I do

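z_init = jnp.array([1, 20]) + numpyro.sample("z_init", dist.Normal(jnp.array([0, 0]), jnp.array([0.001, 0.00001])))

ts = jnp.arange(float(N))
z_all = [z_init]  # collected states, starting with the initial state
for i in range(N-1):
    # integrate over a single interval [ts[i], ts[i+1]]
    z_plus = odeint(dz_dt, z_init, ts[i:i+2], theta, rtol=1e-6, atol=1e-5, mxstep=1000)
    z_all.append(z_plus[1,:])
    z_init = z_plus[1,:]  # the end state becomes the next initial state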

I then use NUTS to infer the parameters of the model dz_dt. Both the time to set up the inference and the time to run the iterations are around 20 to 30 times longer in this case.
Is there a way to avoid this extra computational cost? I do need to do the integration in steps, so calling odeint once over the whole time grid is not an option.

why do you need to split it up like that?

also this appears to be more of a jax question than a numpyro question.

It could be that it is actually a JAX question.

But to answer your question, the ODE I am trying to solve has some discrete inputs that enter at every sampling time (you can think of them as Dirac delta inputs that enter the ODE at every sampling time).
So after every sampling time, I need to add something to or subtract something from z_plus to model the effects of these inputs. For this reason I need to split the integration.

are these inputs deterministic? are they known beforehand?

yes they are known beforehand and are deterministic
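
Since the inputs are known in advance, one option that may be worth trying is to keep the piecewise integration but replace the Python for loop with jax.lax.scan. A Python loop is unrolled when the model is traced, so N-1 separate odeint calls end up in the compiled program, whereas scan traces the step only once. The following is a minimal sketch, not a verified fix for the slowdown reported above: the function name integrate_with_jumps and the jumps array (one known input per interval, added to the state at the end of that interval) are hypothetical, and dz_dt is assumed to be a Lotka-Volterra right-hand side with theta = (alpha, beta, gamma, delta), in the spirit of the predator-prey example.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dz_dt(z, t, theta):
    """Assumed predator-prey right-hand side; theta = (alpha, beta, gamma, delta)."""
    u, v = z[0], z[1]
    alpha, beta, gamma, delta = theta[0], theta[1], theta[2], theta[3]
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


def integrate_with_jumps(z_init, ts, jumps, theta):
    """Integrate interval by interval, adding jumps[i] to the state at ts[i + 1].

    jumps is a hypothetical array of shape (N - 1, 2) holding the known inputs.
    Returns an array of shape (N, 2) with the state at every sampling time.
    """

    def step(z, inputs):
        t_pair, jump = inputs
        # Integrate the continuous dynamics over one interval [t_i, t_{i+1}].
        z_end = odeint(dz_dt, z, t_pair, theta, rtol=1e-6, atol=1e-5, mxstep=1000)[1]
        # Apply the discrete input once the sampling time is reached.
        z_next = z_end + jump
        return z_next, z_next

    t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=1)  # shape (N - 1, 2)
    _, z_rest = jax.lax.scan(step, z_init, (t_pairs, jumps))
    return jnp.concatenate([z_init[None, :], z_rest], axis=0)
```

Inside a NumPyro model, z_init and theta would be the sampled values, and the returned array plays the role of z_all in the loop version. Whether this removes most of the overhead depends on how much of it comes from tracing and compilation rather than from restarting the adaptive solver at every interval.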

can’t you modify dz_dt (the first arg to odeint) to take these into account up front?

I don’t think so. The dz_dt is the right-hand side of a continuous-time ODE. I need to integrate to the next sampling time before adding those inputs. I don’t see a way to modify the dz_dt to take that into account.

can you describe your ode using equations? from what you’re saying i don’t see why this isn’t just some form of reparameterization…

I’ll think about it :). Thanks for your help!