Parallelization, plate and ODEs


I'm relatively new to NumPyro and use it to infer the parameters of a differential equation governing the dynamics of my system. For this purpose, I followed the tutorial dealing with the Lotka-Volterra equations.

The difference in my case is that I have to integrate the ODE several times for different starting conditions (here, starting temperatures). My model looks as follows and works just fine:

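import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.experimental.ode import odeint

# dxdt2, ts (time grid) and Ts (temperatures) are defined elsewhere

def model(y=None):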
    # initial population
    x_init = jnp.array([200.,0.])

    # parameters E_A_1, E_A_2, A of dxdt2
    E_A_1 = numpyro.sample("E_A_1",dist.TruncatedNormal(loc=5.,scale=20.,low=0.))
    E_A_2 = numpyro.sample("E_A_2",dist.TruncatedNormal(loc=5.,scale=20.,low=0.))
    A = numpyro.sample("A",dist.TruncatedNormal(loc=15.,scale=30.,low=0.))

    # sample sigma, i.e. the measurement error
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))

    x = jnp.zeros([len(ts)*len(Ts),2])
    for i,T in enumerate(Ts):
        xi = odeint(dxdt2,x_init,ts,E_A_1,E_A_2,A,T,rtol=1e-6, atol=1e-5, mxstep=1000)
        # .at[...].set(...) replaces jax.ops.index_update, which was removed from JAX
        x = x.at[i*len(ts):(i+1)*len(ts), :].set(xi)

    # measured populations
    numpyro.sample("y", dist.Normal(x, sigma), obs=y)

I am wondering how to parallelize the for loop over the temperatures (Ts), so that the ODE integrations run in parallel rather than sequentially. I tried to use numpyro.plate for this purpose, as shown below, but this does not work. I think the reason is the subsample_size of 1, together with the fact that one cannot simply shuffle all indices at random: those belonging to one temperature are not conditionally independent and are produced by the same odeint call.

with numpyro.plate("data_loop",size=len(Ts),subsample_size=1) as ind:
    Ti = Ts[ind[0]]
    yi = y[ind[0]*len(ts):(ind[0]+1)*len(ts)]
    x = odeint(dxdt2,x_init,ts,E_A_1,E_A_2,A,Ti,rtol=1e-6, atol=1e-5, mxstep=1000)
    numpyro.sample("y", dist.Normal(x, sigma), obs=yi)

Does anybody know a way to parallelize the execution of odeint in this example?



I think what you are looking for is vmap.

Thanks for the quick help! I am also new to jax :smiley:

Here you can vmap over Ts like this:

x = jax.vmap(lambda T: odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T,
                              rtol=1e-6, atol=1e-5, mxstep=1000))(Ts)
x = x.reshape((-1, 2))  # shape: (len(Ts), len(ts), 2) -> (len(Ts) * len(ts), 2)
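
For anyone who wants to check this outside the model first, here is a minimal self-contained sketch. It compares the sequential loop with the vmap version on a toy problem. Note that dxdt_toy is only a hypothetical stand-in for dxdt2 (which is not shown in this thread), and the values of ts, Ts, E_A_1, E_A_2 and A are made up for illustration; in the real model the parameters come from numpyro.sample. It also assumes Ts is a JAX array rather than a Python list.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dxdt_toy(x, t, E_A_1, E_A_2, A, T):
    # Hypothetical stand-in for dxdt2: reversible conversion between two
    # populations with Arrhenius-like rates that depend on temperature T.
    k1 = A * jnp.exp(-E_A_1 / T)
    k2 = A * jnp.exp(-E_A_2 / T)
    flow = k1 * x[0] - k2 * x[1]
    return jnp.stack([-flow, flow])


x_init = jnp.array([200.0, 0.0])
ts = jnp.linspace(0.0, 10.0, 20)   # time grid (assumed values)
Ts = jnp.array([2.0, 4.0, 8.0])    # temperatures (assumed values), as an array
E_A_1, E_A_2, A = 5.0, 8.0, 1.0    # fixed here; sampled inside the real model


def solve(T):
    # One ODE solve for a single temperature, result shape (len(ts), 2).
    return odeint(dxdt_toy, x_init, ts, E_A_1, E_A_2, A, T,
                  rtol=1e-6, atol=1e-5, mxstep=1000)


# Sequential version: one odeint call per temperature, stacked block by block.
x_loop = jnp.concatenate([solve(T) for T in Ts], axis=0)

# Batched version: vmap maps solve over the leading axis of Ts.
x_vmap = jax.vmap(solve)(Ts)        # shape (len(Ts), len(ts), 2)
x_vmap = x_vmap.reshape((-1, 2))    # shape (len(Ts) * len(ts), 2)

print(x_vmap.shape)
print(jnp.max(jnp.abs(x_vmap - x_loop)))  # largest deviation from the loop
```

The reshape keeps all time points of the first temperature together, then those of the second, and so on, so the rows line up with y in the same way as in the loop version. One caveat: under vmap the adaptive solver keeps stepping until the slowest batch element has finished, so how much faster this is than the loop depends on the problem and the hardware.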