 # Parallelization, plate and ODEs

Hi,

I'm relatively new to NumPyro and am using it to infer the parameters of a differential equation governing the dynamics of my system. For this purpose, I followed the tutorial on the Lotka-Volterra equations (http://num.pyro.ai/en/latest/examples/ode.html).

The difference in my case is that I have to integrate the ODE several times for different starting conditions (here, starting temperatures). My model looks as follows and works just fine:

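``````
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.experimental.ode import odeint

# dxdt2, ts (time grid) and Ts (temperatures) are defined elsewhere

def model(y=None):
    # initial population
    x_init = jnp.array([200., 0.])

    # parameters E_A_1, E_A_2, A of dxdt2
    E_A_1 = numpyro.sample("E_A_1", dist.TruncatedNormal(loc=5., scale=20., low=0.))
    E_A_2 = numpyro.sample("E_A_2", dist.TruncatedNormal(loc=5., scale=20., low=0.))
    A = numpyro.sample("A", dist.TruncatedNormal(loc=15., scale=30., low=0.))

    # sample sigma, aka the measurement error
    # (assumed: one scale per state variable, as in the tutorial)
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))

    # integrate once per temperature and stack the results row-wise
    x = jnp.zeros([len(ts) * len(Ts), 2])
    for i, T in enumerate(Ts):
        xi = odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T, rtol=1e-6, atol=1e-5, mxstep=1000)
        # x.at[...].set(...) replaces the older jax.ops.index_update
        x = x.at[i * len(ts):(i + 1) * len(ts), :].set(xi)

    # measured populations
    numpyro.sample("y", dist.Normal(x, sigma), obs=y)
``````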

I am wondering how to parallelize the for loop over the temperatures (`Ts`), so that the ODE integrations run in parallel rather than sequentially. I tried to use `numpyro.plate` for this purpose, as shown below, but this does not work. I think the reason is the `subsample_size` of 1, together with the fact that one cannot simply shuffle all indices at random: the data points belonging to one temperature are not conditionally independent, because they are produced by the same `odeint` call.

``````
with numpyro.plate("data_loop", size=len(Ts), subsample_size=1) as ind:
    Ti = Ts[ind]
    yi = y[ind*len(ts):(ind+1)*len(ts)]
    x = odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, Ti, rtol=1e-6, atol=1e-5, mxstep=1000)
    numpyro.sample("y", dist.Normal(x, sigma), obs=yi)
``````

Does anybody know a way to parallelize the execution of `odeint` in this example?

Best,

Johannes

I think what you are looking for is vmap.

Thanks for the quick help! I am also new to JAX.

Here you can vmap over `Ts` like this:

``````
x = jax.vmap(lambda T: odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T,
                              rtol=1e-6, atol=1e-5, mxstep=1000))(Ts)
x = x.reshape((-1, 2))  # shape: (len(Ts), len(ts), 2) -> (len(Ts) * len(ts), 2)
``````
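If you want to convince yourself that the vmapped call produces the same stacked array as the original for loop, a minimal standalone sketch could look like the one below. Note that `dxdt2` here is only a hypothetical stand-in (a two-state system with Arrhenius-like rates), and the values of `ts`, `Ts`, `E_A_1`, `E_A_2` and `A` are made up for illustration; none of them are taken from the original model.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


# Hypothetical stand-in for dxdt2 (NOT the original equations):
# two states exchanging mass with temperature-dependent rates.
def dxdt2(x, t, E_A_1, E_A_2, A, T):
    k1 = A * jnp.exp(-E_A_1 / T)
    k2 = A * jnp.exp(-E_A_2 / T)
    return jnp.array([-k1 * x[0] + k2 * x[1],
                      k1 * x[0] - k2 * x[1]])


# Illustrative values only (assumptions, not from the thread)
x_init = jnp.array([200.0, 0.0])
ts = jnp.linspace(0.0, 1.0, 5)
Ts = jnp.array([1.0, 2.0, 4.0])
E_A_1, E_A_2, A = jnp.asarray(1.0), jnp.asarray(2.0), jnp.asarray(3.0)


def solve(T):
    # One ODE solve for a single temperature; returns shape (len(ts), 2)
    return odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T,
                  rtol=1e-6, atol=1e-5, mxstep=1000)


# Sequential version: one solve per temperature, stacked row-wise
x_loop = jnp.concatenate([solve(T) for T in Ts])

# Batched version: vmap over the temperature axis, then flatten
x_vmap = jax.vmap(solve)(Ts).reshape((-1, 2))
```

Both arrays have one block of `len(ts)` rows per temperature, in the same order as `Ts`, so the reshaped result should line up with observations `y` that are stacked the same way as in the loop version.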