Using lax.scan for time series in NumPyro

Hi! I recently started using NumPyro for time series models and am having some trouble with lax.scan.

In particular, here I am trying to implement a simple state-space model with order-2 autoregressive dynamics, something like:

z_t ~ Normal(beta1 * z_{t-1} + beta2 * z_{t-2}, tau)
y_t ~ Normal(z_t, sigma)

A vanilla (i.e. without lax.scan) implementation would look something like the following:

import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist

def model(T, T_forecast, obs=None):
    # global parameters and the two initial latent states
    beta1 = numpyro.sample(name="beta_1", fn=dist.Normal(loc=0., scale=1.))
    beta2 = numpyro.sample(name="beta_2", fn=dist.Normal(loc=0., scale=1.))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=1.))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=1.))
    z_prev1 = numpyro.sample(name="z_1", fn=dist.Normal(loc=0., scale=1.))
    z_prev2 = numpyro.sample(name="z_2", fn=dist.Normal(loc=0., scale=1.))

    # AR(2) transitions over the observed window
    Z = []
    for t in range(2, T):
        z_t_mean = beta1*z_prev1 + beta2*z_prev2
        z_t = numpyro.sample(name="z_%d" % (t+1), fn=dist.Normal(loc=z_t_mean, scale=tau))
        Z.append(z_t)
        z_prev1 = z_prev2
        z_prev2 = z_t

    # keep unrolling the same dynamics over the forecast window
    for t in range(T, T+T_forecast):
        z_t_mean = beta1*z_prev1 + beta2*z_prev2
        z_t = numpyro.sample(name="z_%d" % (t+1), fn=dist.Normal(loc=z_t_mean, scale=tau))
        Z.append(z_t)
        z_prev1 = z_prev2
        z_prev2 = z_t

    numpyro.sample(name="y_obs", fn=dist.Normal(loc=jnp.stack(Z[:T-2]), scale=sigma), obs=obs[:T-2])
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=jnp.stack(Z[T-2:]), scale=sigma), obs=None)
    return Z

where I am trying to be Bayesian over all model parameters (and predictions). This implementation works fine; however, because of the Python for-loops, even this simple model takes far too long to compile. Following the time-series forecasting tutorial in the NumPyro documentation, I came up with what seems to me an equivalent implementation, this time using lax.scan:

def scan_fn(carry, x):
    beta1, beta2, z_prev1, z_prev2 = carry
    # AR(2) update; carry the two most recent states forward
    z_t = beta1*z_prev1 + beta2*z_prev2
    z_prev1 = z_prev2
    z_prev2 = z_t
    return (beta1, beta2, z_prev1, z_prev2), z_t
  
def model(T, T_forecast, obs=None):
    beta1 = numpyro.sample(name="beta_1", fn=dist.Normal(loc=0., scale=1.))
    beta2 = numpyro.sample(name="beta_2", fn=dist.Normal(loc=0., scale=1.))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=1.))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=1.))
    z_prev1 = numpyro.sample(name="z_1", fn=dist.Normal(loc=0., scale=1.))
    z_prev2 = numpyro.sample(name="z_2", fn=dist.Normal(loc=0., scale=1.))

    # scan over the remaining T+T_forecast-2 steps to collect the expected latent states
    _Z_exp = [z_prev1, z_prev2]
    _, zs_exp = lax.scan(scan_fn, (beta1, beta2, z_prev1, z_prev2), None, length=T+T_forecast-2)
    Z_exp = jnp.concatenate((jnp.stack(_Z_exp), zs_exp), axis=0)

    Z = numpyro.sample(name="Z", fn=dist.Normal(loc=Z_exp, scale=tau))
    numpyro.sample(name="y_obs", fn=dist.Normal(loc=Z[:T], scale=sigma), obs=obs[:T])
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=Z[T:], scale=sigma), obs=None)
    return Z_exp

where I use lax.scan to collect the expected values of the latent states z_t. The problem is that when I run this model, MCMC inference does not converge to meaningful values. Am I missing something obvious about how lax.scan works, or should I implement the model differently?

Thanks!

You can have a look at this code to see how to add noise (the tau parameter) at each timestep.


@Daniele Have you tested whether your lax.scan version and your for-loop version give the same result?


@vincentbt thanks for the suggestion. I actually did try adding noise at each time step (I didn't mention it here because the results did not change much), using the following implementation, which looks quite similar to the one in the post you mention:

def scan_fn(carry, tau_t):
    beta1, beta2, z_prev1, z_prev2 = carry
    z_t = beta1*z_prev1 + beta2*z_prev2 + tau_t
    z_prev1 = z_prev2
    z_prev2 = z_t
    return (beta1, beta2, z_prev1, z_prev2), z_t

def model(T, T_forecast, obs=None):
    beta1 = numpyro.sample(name="beta_1", fn=dist.Normal(loc=0., scale=1.))
    beta2 = numpyro.sample(name="beta_2", fn=dist.Normal(loc=0., scale=1.))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(T+T_forecast)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=1.))
    z_prev1 = numpyro.sample(name="z_1", fn=dist.Normal(loc=0., scale=1.))
    z_prev2 = numpyro.sample(name="z_2", fn=dist.Normal(loc=0., scale=1.))

    # feed one noise term tau_t to each step of the scan
    _Z_exp = [z_prev1 + tau[0], z_prev2 + tau[1]]
    _, zs_exp = lax.scan(scan_fn, (beta1, beta2, z_prev1, z_prev2), tau[2:])
    Z_exp = jnp.concatenate((jnp.stack(_Z_exp), zs_exp))

    numpyro.sample(name="y_obs", fn=dist.Normal(loc=Z_exp[:T], scale=sigma), obs=obs[:T])
    numpyro.sample(name="y_pred", fn=dist.Normal(loc=Z_exp[T:], scale=sigma), obs=None)
    return Z_exp

but I will definitely take a closer look at the post :wink:

@fehiepsi Yes, I tried them both: the for-loop version models the data sufficiently well (and also imputes missing values), while the lax.scan version gives meaningless results (although sampling seems to have converged). Does this answer your question?

@Daniele I meant testing whether lax.scan and the for loop produce the same output given the same inputs.
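
For example, a quick standalone check of the deterministic recursion, with made-up values (no numpyro involved):

import jax.numpy as jnp
from jax import lax

def step(carry, _):
    # same deterministic AR(2) update as your scan_fn
    beta1, beta2, z_prev1, z_prev2 = carry
    z_t = beta1*z_prev1 + beta2*z_prev2
    return (beta1, beta2, z_prev2, z_t), z_t

beta1, beta2, z1, z2 = 0.5, 0.3, 1.0, -0.5  # arbitrary test values

# scan version
_, zs_scan = lax.scan(step, (beta1, beta2, z1, z2), None, length=5)

# for-loop version of the same recursion
zs_loop = []
z_prev1, z_prev2 = z1, z2
for _ in range(5):
    z_t = beta1*z_prev1 + beta2*z_prev2
    z_prev1, z_prev2 = z_prev2, z_t
    zs_loop.append(z_t)

print(jnp.allclose(zs_scan, jnp.array(zs_loop)))  # True if they agree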

Looking at the two models again, it seems to me that they are different. In the first model, z_prev2 = z_t, which is a noisy sample around z_t_mean, while in the second model, z_prev2 = z_t_mean. To make them equivalent (up to reparameterization), I think you can sample the noise first:

noises = numpyro.sample("noises", fn=dist.Normal(0, 1), sample_shape=(T,))

then in lax.scan, you can convert the statement

z_t = numpyro.sample(name="z_%d"%(t+1), fn=dist.Normal(loc=z_t_mean, scale=tau))

to something like

z_t_mean = beta1*z_prev1 + beta2*z_prev2
z_t = z_t_mean + tau * noise

This is what @vincentbt suggested I believe.
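
Putting those pieces together, the full model might look something like this (just a sketch: I define scan_fn inside model so that tau and the betas are in scope, and size noises to match the scan length):

def model(T, T_forecast, obs=None):
    beta1 = numpyro.sample("beta_1", dist.Normal(0., 1.))
    beta2 = numpyro.sample("beta_2", dist.Normal(0., 1.))
    tau = numpyro.sample("tau", dist.HalfCauchy(1.))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1.))
    z_prev1 = numpyro.sample("z_1", dist.Normal(0., 1.))
    z_prev2 = numpyro.sample("z_2", dist.Normal(0., 1.))
    # sample all innovation noise up front, one term per scanned step
    noises = numpyro.sample("noises", dist.Normal(0., 1.),
                            sample_shape=(T + T_forecast - 2,))

    def scan_fn(carry, noise_t):
        z_prev1, z_prev2 = carry
        z_t_mean = beta1*z_prev1 + beta2*z_prev2
        # non-centered parameterization of z_t ~ Normal(z_t_mean, tau)
        z_t = z_t_mean + tau * noise_t
        return (z_prev2, z_t), z_t

    _, zs = lax.scan(scan_fn, (z_prev1, z_prev2), noises)
    Z = jnp.concatenate((jnp.stack([z_prev1, z_prev2]), zs))
    numpyro.sample("y_obs", dist.Normal(Z[:T], sigma),
                   obs=None if obs is None else obs[:T])
    numpyro.sample("y_pred", dist.Normal(Z[T:], sigma))
    return Z

This way, the only latent sites are the global parameters and noises, and the recursion inside the scan is fully deterministic given them.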


@fehiepsi In that case yes, I did check and the outputs are indeed identical.

I also implemented the model as you suggested, and it does indeed start to work better. However, the MCMC chains seem to achieve a lower effective sample size than the “vanilla” implementation (therefore requiring some tuning of priors/step size/num_steps). Could there be a specific reason for this? And if the reason is indeed the reparameterization, are there any “best practices” for making the estimation more stable?

Thanks :wink: