AR(p) in numpyro?

Hello, we are trying to implement a simple AR(p) model in NumPyro.

  1. Could you please provide a code example or tutorial?

  2. For an AR(p) model, we have a double for loop, over t and over p:
    y(t) = b0 + b1*y(t-1) + b2*y(t-2) + ... + bp*y(t-p)

Then the model would be (maybe?):

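import numpyro
import numpyro.distributions as dist

def model(y_obs, K):
    b = numpyro.sample('b', dist.Normal(0., 10.).expand([1 + K]))
    for t in range(K, len(y_obs)):  # start at K so every lag y_obs[t - p] exists
        mu = b[0]  # reset the mean at each time step
        for p in range(1, K + 1):
            mu += b[p] * y_obs[t - p]
        numpyro.sample('obs_{}'.format(t), dist.Normal(mu, 10.), obs=y_obs[t])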

How can we implement this efficiently in NumPyro and avoid the double for loop (over t and over p)?


Hi @pookie, for the likelihood, you can do

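import numpy as np
def stack_and_shift(y, p):
    # convert y (a NumPy array of length T) to the (T - p, p) lag matrix
    #    y[p-1]  y[p-2]  ...  y[0]
    #    y[p]    y[p-1]  ...  y[1]
    #    ...
    #    y[T-2]  y[T-3]  ...  y[T-p-1]
    # so that column j holds lag j + 1 and lines up with b[1:] = [b1, ..., bp]
    T = len(y)
    return np.stack([y[p - j:T - j] for j in range(1, p + 1)], axis=-1)

mu = stack_and_shift(y_obs, K) @ b[1:] + b[0]
numpyro.sample("obs", dist.Normal(mu, 10), obs=y_obs[K:])  # first K values lack a full history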

For sampling, you can draw Gaussian noise (one value per time step) and use scan with the formula

 y(t) = b0 + b1*y(t-1) + b2*y(t-2) + ... + bp*y(t-p) + noise

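to get sampled values. The carry will be something like y_last_p_values, a vector of shape (p,) holding the most recent values (newest first, so it lines up with b[1:]); at each step you compute y_last_p_values @ b[1:] + b[0] and add that step's noise.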
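A minimal sketch of that scan-based sampler, using `jax.lax.scan` directly outside any model. The names `sample_ar`, `y_init`, and `num_steps` are hypothetical; the carry is the window of the p most recent values, newest first, so a plain dot product with `b[1:]` gives the AR term.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sample_ar(rng_key, b, noise_scale, num_steps, y_init):
    """Simulate y(t) = b0 + b1*y(t-1) + ... + bp*y(t-p) + noise with lax.scan.

    b = [b0, b1, ..., bp]; y_init holds the p values before the first step,
    most recent first, so that it lines up with b[1:].
    """
    # one Gaussian noise draw per time step
    noise = noise_scale * jax.random.normal(rng_key, (num_steps,))

    def step(y_last_p_values, eps):
        y_t = b[0] + y_last_p_values @ b[1:] + eps
        # shift the window: put the new value in front, drop the oldest one
        y_last_p_values = jnp.concatenate([y_t[None], y_last_p_values[:-1]])
        return y_last_p_values, y_t

    _, ys = lax.scan(step, y_init, noise)
    return ys


# usage: b = [0, 1, 1] with history y(0) = 1, y(-1) = 0 and no noise
# reproduces the Fibonacci numbers, ys == [1., 2., 3., 5.]
ys = sample_ar(jax.random.PRNGKey(0), jnp.array([0.0, 1.0, 1.0]), 0.0, 4,
               jnp.array([1.0, 0.0]))
```

Keeping the window newest first means the shift is a single `concatenate`; `y_init` must have the same dtype as the sampled values, since `lax.scan` requires the carry type to stay fixed.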

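If you are using a GPU and want to sample with as much parallelism as possible, it might be better to treat AR(p) as a state-space model, with state [y(t-p+1), ..., y(t)] (oldest first; b0 plus the noise is added to the last component at each step) and the transition matrix

transition_matrix = [0  1  0  0 ... 0,
                     0  0  1  0 ... 0,
                     0  0  0  1 ... 0,
                     ...
                     0  0  0  0 ... 1,
                     bp ... b3 b2 b1]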

and applying a parallel scan, similar to how GaussianHMM.rsample is implemented. (I don't recommend this approach, though, mainly because it requires more math and is more complicated to implement.)
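For completeness, a minimal sketch of the parallel-scan idea. It swaps in `jax.lax.associative_scan` over affine maps x_t = A @ x_{t-1} + c_t rather than reproducing the GaussianHMM code, and `companion_matrix` and `sample_ar_parallel` are hypothetical names. The state is [y(t-p+1), ..., y(t)], oldest first, so the last row of A is [bp, ..., b1], and b0 plus the noise enters through c_t on the last component.

```python
import jax
import jax.numpy as jnp
from jax import lax


def companion_matrix(b):
    # b = [b0, b1, ..., bp]; state x_t = [y(t-p+1), ..., y(t)], oldest first
    p = b.shape[0] - 1
    A = jnp.zeros((p, p), dtype=b.dtype)
    A = A.at[:-1, 1:].set(jnp.eye(p - 1, dtype=b.dtype))  # shift the window by one step
    A = A.at[-1].set(b[1:][::-1])  # last row: bp, ..., b2, b1
    return A


def sample_ar_parallel(rng_key, b, noise_scale, num_steps, x0):
    """Simulate AR(p) with an associative scan over affine maps.

    Each step is x_t = A @ x_{t-1} + c_t, where c_t carries (b0 + noise_t) on the
    last component. Composing affine maps is associative, so every prefix
    x_t = A_cum_t @ x0 + c_cum_t can be computed in O(log T) parallel depth.
    x0 holds the p values before the first step, oldest first.
    """
    A = companion_matrix(b)
    p = A.shape[0]
    noise = noise_scale * jax.random.normal(rng_key, (num_steps,), dtype=b.dtype)
    As = jnp.broadcast_to(A, (num_steps, p, p))
    cs = jnp.zeros((num_steps, p), dtype=b.dtype).at[:, -1].set(b[0] + noise)

    def compose(earlier, later):
        # apply `earlier` first, then `later`: x -> A2 @ (A1 @ x + c1) + c2
        A1, c1 = earlier
        A2, c2 = later
        return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, c1) + c2

    A_cum, c_cum = lax.associative_scan(compose, (As, cs))
    states = jnp.einsum("tij,j->ti", A_cum, x0) + c_cum
    return states[:, -1]  # y(1), ..., y(num_steps)


# usage: b = [0, 1, 1] with history y(-1) = 0, y(0) = 1 (oldest first)
# and no noise gives the Fibonacci numbers
ys = sample_ar_parallel(jax.random.PRNGKey(0), jnp.array([0.0, 1.0, 1.0]), 0.0, 4,
                        jnp.array([0.0, 1.0]))
```

The trade-off mentioned above shows up here: this version materializes a p x p matrix per time step and does O(p^3) work per combine, so it only pays off when the sequential depth of a plain scan is the bottleneck.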


Thank you @fehiepsi. We implemented stack_and_shift(). It looks like stack_and_shift() should be called before starting the inference algorithm, since it only involves the y values (it does not need to be inside the model)?
Then the y matrix is passed to the model to compute mu?

How is scan called in this case?

You are right. If y is a NumPy array (rather than a JAX device array), you can build the matrix with NumPy operations, and the result will be treated as a constant when JAX compiles the program. This way, you can call it inside the model without worrying about performance. :slight_smile:

As for how scan is called: I guess the best way is to create an AR distribution with sample and log_prob methods and use it in your model

numpyro.sample('obs', AR(coefs=b, noise_scale=10), obs=y_obs)

Alternatively, you can use NumPyro's scan primitive (you can mimic the GaussianHMM example in the scan documentation):

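import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(y_obs, K):
    b = numpyro.sample('b', dist.Normal(0., 10.).expand([1 + K]))
    def transition(y_recents, y_curr):  # y_recents: K most recent values, newest first
        mu = b[0] + b[1:] @ y_recents
        y_curr = numpyro.sample('obs', dist.Normal(mu, 10), obs=y_curr)
        y_recents = jnp.concatenate([y_curr[None], y_recents[:-1]])
        return y_recents, y_curr
    y_init = jnp.zeros(K)  # length K == len(b) - 1, to match b[1:]
    _, ys = scan(transition, y_init, y_obs)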

For prediction, you can simply put the scan primitive under a condition handler (the same code as in the time series forecasting tutorial).

I think the scan approach is more convenient and more readable than the first approach (the stack_and_shift matrix), but it is a bit slower (though I suspect the difference is small for short time series). If you are looking for performance, you can add if/else logic:

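def model(*args, forecast=False):
    if not forecast:
        # use the first approach (vectorized likelihood with the stack_and_shift matrix)
        ...
    else:
        # use the second approach (scan under a condition handler)
        ...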

Thank you @fehiepsi. :smile: This is very useful.