How to feed time series of different lengths into a hierarchical model?

Hello! First of all, I would like to thank the developers for all the effort they have put into this library. I really appreciate your work!

I am new to NumPyro and have recently been testing partial pooling on hierarchical time series data (mostly based on the predator-prey tutorial). Let's say we have a dataset of two parallel 2D time series (prey and predator) from the same region, but with different time spans (Ts=15 and Ts=25).

While setting up the hierarchical levels is manageable, I am stuck on how to pass the long-format time series into odeint according to its grouping.

Currently, I have a model like this:

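import os

import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.experimental.ode import odeint
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

# dz_dt_h (the ODE right-hand side) is defined elsewhere
def hierarchical_model(y=None, level=None):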
   
    z_init = numpyro.sample("z_init", dist.LogNormal(1, 1).expand([2]))
    N = y.shape[0]
    ts = jnp.arange(float(N))
    
    # HalfNormal has a single (scale) parameter
    r_μ = numpyro.sample("r_μ", dist.HalfNormal(1))
    c_μ = numpyro.sample("c_μ", dist.HalfNormal(5))
    m_μ = numpyro.sample("m_μ", dist.HalfNormal(1))
    e_μ = numpyro.sample("e_μ", dist.HalfNormal(1))

    r_σ = numpyro.sample("r_σ", dist.HalfNormal(1))
    c_σ = numpyro.sample("c_σ", dist.HalfNormal(1))
    m_σ = numpyro.sample("m_σ", dist.HalfNormal(1))
    e_σ = numpyro.sample("e_σ", dist.HalfNormal(1))

    nr_levels = len(np.unique(level))
   
    with numpyro.plate("regions", nr_levels):
        
        theta_r = numpyro.sample("theta_r", dist.Normal(r_μ, r_σ))
        theta_c = numpyro.sample("theta_c", dist.Normal(c_μ, c_σ))
        theta_m = numpyro.sample("theta_m", dist.Normal(m_μ, m_σ))
        theta_e = numpyro.sample("theta_e", dist.Normal(e_μ, e_σ))
  
    z = numpyro.deterministic('z', odeint(dz_dt_h, z_init, ts, theta_r[level], theta_c[level], theta_m[level], theta_e[level], mxstep=1000) )
    sigma = numpyro.sample("sigma", dist.LogNormal(1, 1).expand([2]))
    numpyro.sample("y", dist.Normal(z, sigma), obs=y)


mcmc = MCMC(
    NUTS(hierarchical_model, dense_mass=True),
    num_warmup=1000,
    num_samples=1000,
    num_chains=1,
    progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)

mcmc.run(PRNGKey(1262), y=values, level=grouping)
mcmc.print_summary()

where values (the data) is in long format, with shape (40, 2):

DeviceArray([[ 0,  0],
 [ 0,  0],
 [ 0,  7],
 [ 1,  3],
.....
 [ 0, 43],
 [ 0, 60],
 [ 0, 46],
 [ 0, 43]], dtype=int32)

and grouping, with shape (40,), assigns each row to its group:

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

At the moment I am feeding all of the data into odeint without taking the grouping into account (as you can see, ts spans the whole dataset). This is incorrect, because the two groups are separate time series rather than one continuous sequence. Subsampling seems like an option, but the time span differs between groups (15 steps for one, 25 for the other). Does anyone have an idea of how to feed the data in correctly?
I looked through the Q&A and this post seems to be the closest match. I hope I didn't miss anything.
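A minimal NumPy sketch of one way to split the long-format data, assuming the rows of each group are contiguous and already sorted by time (as in `grouping` above). `lengths`, `starts`, `local_t`, `segments`, `grids` and the placeholder `values` are illustrative names, not part of the original code. The idea is that every group gets its own time axis that restarts at 0, which is what each odeint call should see.

```python
import numpy as np

# Illustrative setup: two groups of 15 and 25 rows, stacked in long format
# like `grouping` above. `values` is a placeholder for the real (40, 2) data.
lengths = np.array([15, 25])
grouping = np.repeat(np.arange(len(lengths)), lengths)  # shape (40,)
values = np.arange(2 * lengths.sum()).reshape(-1, 2)    # shape (40, 2)

# Row index at which each group starts in the long-format array.
starts = np.concatenate([[0], np.cumsum(lengths)[:-1]])

# Time index within each group: it restarts at 0 whenever a new group begins.
local_t = np.arange(len(grouping)) - starts[grouping]
print(local_t[13:17].tolist())  # [13, 14, 0, 1]

# Per-group segments, each paired with its own time grid starting at 0.
segments = np.split(values, np.cumsum(lengths)[:-1])
grids = [np.arange(float(len(seg))) for seg in segments]
```

Each `grids[g]` can then be passed to odeint together with group g's parameters, and the per-group results stacked back in the same row order as `values`.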

Hi, I think I have figured it out, and here is my approach.
However, I have three questions about using vmap:

  1. The vmap version (13 min) is actually slower than the sequential version (11 min). What could be the reason? Am I implementing it incorrectly?
  2. Since my data are time series, the order is quite important. But I noticed that after the vmap, the output order is not the same as the order of my input data. Could this cause a divergence in the inference, since I call numpyro.sample("y", dist.Normal(z, sigma), obs=y) at the end? Should I do the observation separately within the plates?
  3. For the vmap method, since vmap requires every mapped call to have the same (static) shape, I couldn't feed time periods of different lengths into the function. Is there another way of doing this?
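On question 3, one common workaround (a sketch, not necessarily the best option) is to integrate every group over the longest time grid, so all vmapped calls have the same shape, and then pick out, for each long-format row, the state of its own group at its own local time. This also produces `z` in exactly the row order of `y`, which is relevant to question 2. `dz_dt` is a hypothetical Lotka-Volterra stand-in for `dz_dt_h`, and all parameter values are made up.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dz_dt(z, t, r, c, m, e):
    """Hypothetical Lotka-Volterra right-hand side, a stand-in for dz_dt_h."""
    prey, pred = z[0], z[1]
    return jnp.stack([(r - c * pred) * prey, (-m + e * prey) * pred])


# Long-format layout: group label and within-group time index for each row.
lengths = np.array([15, 25])
grouping = np.repeat(np.arange(len(lengths)), lengths)
starts = np.concatenate([[0], np.cumsum(lengths)[:-1]])
local_t = np.arange(lengths.sum()) - starts[grouping]

# One made-up parameter value per group, and a shared initial state.
theta_r = jnp.array([0.6, 0.7])
theta_c = jnp.array([0.05, 0.04])
theta_m = jnp.array([0.5, 0.45])
theta_e = jnp.array([0.02, 0.025])
z_init = jnp.array([10.0, 5.0])

# All groups share the longest grid, so every vmapped call has the same shape.
ts_max = jnp.arange(float(lengths.max()))


def solve(r, c, m, e):
    return odeint(dz_dt, z_init, ts_max, r, c, m, e)


z_padded = jax.vmap(solve)(theta_r, theta_c, theta_m, theta_e)  # (2, 25, 2)

# For each long-format row, take its own group's state at its own local time.
# Steps past the end of a shorter group are computed but never selected.
z_long = z_padded[grouping, local_t]  # (40, 2), same row order as y
```

The price is wasted integration for the shorter groups (here 10 extra steps for group 0). With two groups that is cheap; with very unequal lengths the sequential loop may be competitive.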

Thanks a lot! :slight_smile:

# (fragment from inside the model body; data, level, z_init and theta_* are defined earlier)
if method == 'Normal':
    z_list = []
    for i in np.unique(level):
        # measurement times for group i
        ts = jnp.arange(float(len(data[data.comp_code == i].Week)))
        # integrate dz/dt; the result has shape N x (number of state variables)
        # every deterministic site needs a unique name, hence f"z_{i}"
        z_i = numpyro.deterministic(f"z_{i}", odeint(dz_dt_h, z_init, ts, theta_r[i], theta_c[i], theta_m[i], theta_e[i], theta_K[i], mxstep=10000))
        z_list.append(z_i)
    z = jnp.vstack(z_list)

if method == 'Vmap':
    arr = np.array(data.groupby(['comp_code'])['Week'].count()).astype(float)
    # every group is integrated over the same fixed grid of 17 steps
    x = jax.vmap(lambda i: odeint(dz_dt_h, z_init, jnp.arange(float(17)), theta_r[i], theta_c[i], theta_m[i], theta_e[i], theta_K[i], mxstep=10000))(jnp.arange(len(arr)))
    z = jnp.vstack(x)

sigma = numpyro.sample("sigma", dist.LogNormal(1, 1).expand([4]))
# measured populations
numpyro.sample("y", dist.Normal(z, sigma), obs=y)
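On question 2: `jax.vmap` does not reorder anything; output `k` along the mapped axis is the result for input `k`. A quick check with a toy stand-in for the per-group odeint call (`per_group` and `theta_r` here are hypothetical) is below. One plausible explanation for the misalignment is the fixed 17-step grid in the Vmap branch: stacking the blocks gives `n_groups * 17` rows, which matches the long-format `y` only if every group has exactly 17 rows.

```python
import numpy as np
import jax
import jax.numpy as jnp

theta_r = jnp.array([0.6, 0.7])  # one made-up parameter per group
ts = jnp.arange(17.0)            # fixed grid, as in the Vmap branch


def per_group(i):
    # Hypothetical stand-in for the per-group odeint call: any function of i.
    return jnp.stack([theta_r[i] * ts, -theta_r[i] * ts], axis=-1)  # (17, 2)


out = jax.vmap(per_group)(jnp.arange(2))            # (2, 17, 2)
loop = jnp.stack([per_group(i) for i in range(2)])  # same thing, sequentially

# vmap keeps the order of the mapped axis: out[k] is the result for input k.
print(bool(np.allclose(out, loop)))  # True

# Stacking the fixed-length blocks gives 2 * 17 rows, regardless of how many
# rows each group actually has in the long-format data.
z = jnp.vstack(out)
print(z.shape)  # (34, 2)
```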

Hi fehiepsi, sorry to tag you on this,
but do you have any insight into the performance problem with vmap?

I also wonder whether the plating is correct.
Currently z is a 2D array with the two time series combined, and y is observed in the same layout. Should I put the observation of y inside the plates instead?

I'm not familiar with vmap of odeint. I think you can benchmark things separately, outside of the NumPyro context.
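Following that suggestion, a minimal sketch of benchmarking the two strategies outside NumPyro: jit both, exclude the first (compiling) call, and wait on each result with `block_until_ready`, since JAX dispatches work asynchronously. `dz_dt`, `solve_loop`, `solve_vmap` and `bench` are hypothetical names, and the parameters are made up.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dz_dt(z, t, r, c, m, e):
    """Hypothetical Lotka-Volterra right-hand side, a stand-in for dz_dt_h."""
    prey, pred = z[0], z[1]
    return jnp.stack([(r - c * pred) * prey, (-m + e * prey) * pred])


ts = jnp.arange(25.0)
z_init = jnp.array([10.0, 5.0])
theta = jnp.array([[0.6, 0.05, 0.5, 0.02],    # group 0: (r, c, m, e)
                   [0.7, 0.04, 0.45, 0.025]])  # group 1


def solve_one(th):
    return odeint(dz_dt, z_init, ts, th[0], th[1], th[2], th[3])


@jax.jit
def solve_loop(theta):
    # The Python loop is unrolled at trace time: one odeint call per group.
    return jnp.stack([solve_one(theta[i]) for i in range(theta.shape[0])])


@jax.jit
def solve_vmap(theta):
    return jax.vmap(solve_one)(theta)


def bench(fn, *args, n=20):
    fn(*args).block_until_ready()  # first call compiles; keep it out of the timing
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args).block_until_ready()
    return (time.perf_counter() - t0) / n


t_loop = bench(solve_loop, theta)
t_vmap = bench(solve_vmap, theta)
print(f"loop: {t_loop * 1e3:.2f} ms, vmap: {t_vmap * 1e3:.2f} ms")
```

One plausible reason for vmap being slower: odeint's adaptive stepping runs inside a `while_loop`, and a batched `while_loop` keeps iterating until every element of the batch has finished, so all groups pay for the group that needs the most steps. Actual timings will depend on hardware and JAX version.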