I'm relatively new to NumPyro and am using it to infer the parameters of a differential equation that governs the dynamics of my system. To do so, I followed the tutorial on the Lotka-Volterra equations (http://num.pyro.ai/en/latest/examples/ode.html).
The difference in my case is that I have to integrate the ODE several times for different starting conditions (here, different starting temperatures). My model looks as follows and works just fine:
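```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.experimental.ode import odeint

# ts (time points), Ts (starting temperatures) and dxdt2(x, t, E_A_1, E_A_2, A, T)
# are defined elsewhere in the script.

def model(y=None):
    # initial population
    x_init = jnp.array([200., 0.])
    # parameters E_A_1, E_A_2, A of dxdt2
    E_A_1 = numpyro.sample("E_A_1", dist.TruncatedNormal(loc=5., scale=20., low=0.))
    E_A_2 = numpyro.sample("E_A_2", dist.TruncatedNormal(loc=5., scale=20., low=0.))
    A = numpyro.sample("A", dist.TruncatedNormal(loc=15., scale=30., low=0.))
    # sample sigma, a.k.a. measurement error (one scale per state component)
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    x = jnp.zeros([len(ts) * len(Ts), 2])
    for i, T in enumerate(Ts):
        xi = odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T,
                    rtol=1e-6, atol=1e-5, mxstep=1000)
        # write the trajectory for temperature i into its block of rows
        x = x.at[i * len(ts):(i + 1) * len(ts), :].set(xi)
    # measured populations
    numpyro.sample("y", dist.Normal(x, sigma), obs=y)
```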
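The loop above fills `x` block by block, one temperature at a time. Below is a minimal sketch of building the same stacked array with `jax.vmap`, which vectorizes the solve over the temperature axis instead of running it sequentially. It assumes `odeint` from `jax.experimental.ode`; the `dxdt2` used here is a hypothetical Arrhenius-style stand-in (the real right-hand side is not shown), and `ts`, `Ts` and the parameter values are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Hypothetical right-hand side standing in for dxdt2 (illustration only):
# x0 -> x1 -> (removed), with temperature-dependent rates.
def dxdt2(x, t, E_A_1, E_A_2, A, T):
    k1 = A * jnp.exp(-E_A_1 / T)
    k2 = A * jnp.exp(-E_A_2 / T)
    return jnp.array([-k1 * x[0], k1 * x[0] - k2 * x[1]])

# Hypothetical time grid, starting temperatures and parameter values.
ts = jnp.linspace(0., 10., 20)
Ts = jnp.array([1., 2., 3.])
x_init = jnp.array([200., 0.])
E_A_1, E_A_2, A = jnp.array(5.), jnp.array(5.), jnp.array(15.)

def solve(T):
    # One ODE solve for a single temperature; returns shape (len(ts), 2).
    return odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, T,
                  rtol=1e-6, atol=1e-5, mxstep=1000)

# Sequential version, as in the loop: stack the per-temperature blocks row-wise.
x_loop = jnp.concatenate([solve(T) for T in Ts], axis=0)

# Vectorized version: vmap maps solve over the temperature axis, giving
# shape (len(Ts), len(ts), 2); reshape to the same stacked row layout.
x_vmap = jax.vmap(solve)(Ts).reshape(-1, 2)
```

Because `vmap` batches the solver's internal adaptive-step loop, all temperatures are stepped together until the slowest one reaches each output time, so the benefit depends on how similar the individual solves are.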
I am wondering how to parallelize the for loop over temperatures (`Ts`) so that the ODE integrations run in parallel rather than sequentially. I tried to use `numpyro.plate` for this purpose, as shown below, but it does not work. I think the reason is the `subsample_size` of 1, together with the fact that one cannot simply shuffle all indices at random: the points belonging to one temperature are not conditionally independent and are produced by the same ODE integration.
```python
# replaces the for loop inside model()
with numpyro.plate("data_loop", size=len(Ts), subsample_size=1) as ind:
    Ti = Ts[ind]
    yi = y[ind * len(ts):(ind + 1) * len(ts)]
    x = odeint(dxdt2, x_init, ts, E_A_1, E_A_2, A, Ti,
               rtol=1e-6, atol=1e-5, mxstep=1000)
    numpyro.sample("y", dist.Normal(x, sigma), obs=yi)
```
Does anyone know a way to parallelize the execution of `odeint` in this example?