Using lax.scan for time series in NumPyro

@vincentbt thanks for the suggestion. I had actually already tried adding noise at each time step, but didn't mention it here, mostly because the results didn't change much. Here is that implementation (it seems quite similar to the one in the post you mentioned):

import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist

def scan_fn(carry, tau_t):
    # z_prev1 is z_{t-2}, z_prev2 is z_{t-1}
    beta1, beta2, z_prev1, z_prev2 = carry
    z_t = beta1*z_prev1 + beta2*z_prev2 + tau_t
    # shift the two lags forward by one step
    z_prev1 = z_prev2
    z_prev2 = z_t
    return (beta1, beta2, z_prev1, z_prev2), z_t

def model(T, T_forecast, obs=None):
    beta1 = numpyro.sample(name="beta_1", fn=dist.Normal(loc=0., scale=1.))
    beta2 = numpyro.sample(name="beta_2", fn=dist.Normal(loc=0., scale=1.))
    # one noise term per time step (observed + forecast horizon)
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(T+T_forecast)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=1.))
    z_prev1 = numpyro.sample(name="z_1", fn=dist.Normal(loc=0., scale=1.))
    z_prev2 = numpyro.sample(name="z_2", fn=dist.Normal(loc=0., scale=1.))

    _Z_exp = [z_prev1 + tau[0], z_prev2 + tau[1]]
    # run the AR(2) recursion over the remaining T+T_forecast-2 steps
    __, zs_exp = lax.scan(scan_fn, (beta1, beta2, z_prev1, z_prev2), tau[2:], T+T_forecast-2)
    Z_exp = jnp.concatenate((jnp.array(_Z_exp), zs_exp))

    # obs may be None when forecasting, so only slice it when it is given
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=Z_exp[:T], scale=sigma), obs=None if obs is None else obs[:T])
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=Z_exp[T:], scale=sigma), obs=None)
    return Z_exp

That said, I will definitely take a closer look at the post :wink:
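Written out with 1-based indices (so `tau[0]` is $\tau_1$ and the scan runs over $\tau_3, \dots$), the means that the code above computes are as follows. Note that the first two means include $\tau_1, \tau_2$, while the scan is seeded with the raw $z_1, z_2$, and that since each $\tau_t$ comes from a HalfCauchy, every innovation is strictly positive.

```latex
\begin{aligned}
\mu_1 &= z_1 + \tau_1, \qquad \mu_2 = z_2 + \tau_2,\\
\mu_3 &= \beta_1 z_1 + \beta_2 z_2 + \tau_3,\\
\mu_4 &= \beta_1 z_2 + \beta_2 \mu_3 + \tau_4,\\
\mu_t &= \beta_1 \mu_{t-2} + \beta_2 \mu_{t-1} + \tau_t, \qquad t = 5, \dots, T + T_{\text{forecast}},\\
y_t &\sim \mathcal{N}(\mu_t, \sigma).
\end{aligned}
```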

@fehiepsi Yes, I tried both. The for-loop version models the data reasonably well (and also imputes the missing values), while the lax.scan version gives meaningless results, even though sampling appears to have converged. Does this answer your question?
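One way to check whether the gap comes from the recursion itself is to run the same `scan_fn` transition deterministically with `lax.scan` and with a plain for-loop, then compare the trajectories. This is a minimal sketch, assuming fixed, made-up values for `beta1`, `beta2`, `z_1`, `z_2` and the noise terms; the helpers `ar2_scan` and `ar2_loop` are hypothetical names. If the two agree, the transition is not the culprit and the difference more likely lies in how the model is set up around it.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def scan_fn(carry, tau_t):
    # Same transition as in the model: z_prev1 is z_{t-2}, z_prev2 is z_{t-1}
    beta1, beta2, z_prev1, z_prev2 = carry
    z_t = beta1 * z_prev1 + beta2 * z_prev2 + tau_t
    z_prev1 = z_prev2
    z_prev2 = z_t
    return (beta1, beta2, z_prev1, z_prev2), z_t


def ar2_scan(beta1, beta2, z1, z2, taus):
    # Hypothetical helper: cast everything to float32 so the carry dtypes stay fixed
    taus = jnp.asarray(taus, dtype=jnp.float32)
    init = tuple(jnp.asarray(v, dtype=jnp.float32) for v in (beta1, beta2, z1, z2))
    # lax.scan threads the carry through taus and stacks every z_t
    _, zs = lax.scan(scan_fn, init, taus)
    return zs


def ar2_loop(beta1, beta2, z1, z2, taus):
    # Hypothetical reference implementation with a plain Python for-loop
    z_prev1, z_prev2 = z1, z2
    zs = []
    for tau_t in taus:
        z_t = beta1 * z_prev1 + beta2 * z_prev2 + tau_t
        zs.append(z_t)
        z_prev1, z_prev2 = z_prev2, z_t
    return np.array(zs)


taus = np.linspace(0.1, 1.0, 10)  # made-up noise values
zs_scan = ar2_scan(0.5, -0.3, 1.0, 2.0, taus)
zs_loop = ar2_loop(0.5, -0.3, 1.0, 2.0, taus)
print(np.allclose(np.asarray(zs_scan), zs_loop, atol=1e-5))  # True
```

The float32 casts are there because `lax.scan` requires the carry to keep the same shape and dtype on every step.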