You are right. If `y` is a NumPy array (instead of a JAX device array), then you can use `numpy.foo` operators to get the matrix. The output will be a constant when JAX compiles the program, so you can call it inside the model without worrying about performance.
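As a rough sketch of what building that matrix with plain NumPy could look like: the helper below borrows the `stack_and_shift` name from this thread, but its body, the toy series, and the coefficient values are my own assumptions for illustration, not the original implementation.

```python
import numpy as np

def stack_and_shift(y, p):
    """Hypothetical helper: row t holds the p values preceding y[t + p], newest first."""
    y = np.asarray(y)
    T = y.shape[0]
    # Column k is the series lagged by k + 1 steps relative to the targets y[p:].
    return np.stack([y[p - k - 1 : T - k - 1] for k in range(p)], axis=1)

y = np.arange(6.0)               # toy series: 0, 1, 2, 3, 4, 5
X = stack_and_shift(y, 2)        # shape (4, 2); first row is [y[1], y[0]]
b = np.array([0.5, 1.0, -1.0])   # illustrative values: intercept, lag-1, lag-2
mu = b[0] + X @ b[1:]            # AR(2) means for the targets y[2:]
```

Because everything here is ordinary NumPy on a concrete array, the matrix is computed once in Python rather than traced step by step.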

How is scan called in this case?

I guess the best way is to create an `AR` distribution with `sample` and `log_prob` methods and use it in your model:

```
numpyro.sample('obs', AR(coefs=b, noise_scale=10), obs=y_obs)
```
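To give an idea of what the `log_prob` of such a distribution would have to compute, here is a minimal sketch as a standalone function. The name `ar_log_prob`, the zero-padded initial history (chosen to match a zero initial state), and the requirement of at least one lag are all assumptions on my side; wrapping it into a proper `Distribution` subclass with a `sample` method is left out.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def ar_log_prob(y, coefs, noise_scale):
    """Hypothetical AR(p) log-density; coefs = [intercept, lag-1, ..., lag-p], p >= 1."""
    p = coefs.shape[0] - 1
    T = y.shape[0]
    # Assumption: the series is preceded by p zeros.
    y_padded = jnp.concatenate([jnp.zeros(p), y])
    # Column k holds the values lagged by k + 1 steps relative to y.
    lags = jnp.stack(
        [y_padded[p - k - 1 : p - k - 1 + T] for k in range(p)], axis=1
    )
    mu = coefs[0] + lags @ coefs[1:]
    # Independent Gaussian noise around the conditional means.
    return norm.logpdf(y, loc=mu, scale=noise_scale).sum()

lp = ar_log_prob(jnp.array([1.0, 2.0]), jnp.array([0.5, 1.0]), 1.0)
```

If you do not need sampling, a value like this could also be added to the model density with `numpyro.factor('obs', ar_log_prob(y_obs, b, 10.))` instead of writing a full distribution class.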

Alternatively, you can use the NumPyro `scan` primitive (you can mimic the GaussianHMM example there):

```
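import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def transition(y_recents, y_curr):
    # b[0] is the intercept, b[1:] are the lag coefficients
    mu = b[0] + b[1:] @ y_recents
    y_curr = numpyro.sample('obs', dist.Normal(mu, 10), obs=y_curr)
    # shift the window: newest value first, oldest value dropped
    y_recents = jnp.concatenate([y_curr[None], y_recents[:-1]])
    return y_recents, y_curr

y_init = jnp.zeros(len(b) - 1)  # one slot per lag coefficient in b[1:]
_, ys = scan(transition, y_init, y_obs)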
```

For prediction, you can simply put the `scan` primitive under a `condition` handler (the same code as in the time series forecasting tutorial).

I think the latter approach is more convenient and more readable than the first approach (which uses the `stack_and_shift` matrix), but it is a bit slower (though I suspect the speed difference is small if the time series is short). If you are after performance, you can add if/else logic:

```
def model(..., forecast=False):
    if not forecast:
        # use the first approach
    else:
        # use the second approach
```