You are right. If `y` is a NumPy array (instead of a JAX device array), then you can use regular `numpy.foo` operations to build the matrix. JAX treats the result as a constant when it compiles the program, so you can call this inside the model without worrying about performance.
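A minimal sketch of such a precomputed lag matrix. It assumes `y` is a plain NumPy array, an AR(p) mean of the form `b[0] + b[1] * y[t-1] + ... + b[p] * y[t-p]`, and zeros standing in for values before the series starts. The helper name `lag_matrix` is hypothetical and only plays the role of the `stack_and_shift` matrix discussed in this thread. Because it runs in NumPy on concrete data, JAX only ever sees its output as a constant.

```python
import numpy as np


def lag_matrix(y, p):
    """Hypothetical helper: row t is [y[t-1], ..., y[t-p]], zero-padded at the start."""
    y = np.asarray(y, dtype=float)
    T = y.shape[0]
    X = np.zeros((T, p))
    for k in range(1, min(p, T) + 1):
        # Column k-1 holds the series delayed by k steps.
        X[k:, k - 1] = y[: T - k]
    return X


y = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([0.5, 1.0, -0.5])  # intercept followed by p = 2 lag coefficients
X = lag_matrix(y, p=len(b) - 1)
mu = b[0] + X @ b[1:]  # AR means for all time steps in one matrix product
print(mu.tolist())  # [0.5, 1.5, 2.0, 2.5]
```

Inside a model only `X` has to be concrete; `b` can still be a sampled JAX array, e.g. `numpyro.sample('obs', dist.Normal(b[0] + jnp.dot(X, b[1:]), 10), obs=y)`.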
> How is `scan` called in this case?
I guess the best way is to create an `AR` distribution with `sample` and `log_prob` methods and use it in your model:

```python
numpyro.sample('obs', AR(coefs=b, noise_scale=10), obs=y_obs)
```
Alternatively, you can use NumPyro's `scan` primitive (you can mimic the GaussianHMM example in its documentation):
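```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(b, y_obs):
    def transition(y_recents, y_curr):
        mu = b[0] + b[1:] @ y_recents  # intercept + lag terms (most recent value first)
        y_curr = numpyro.sample('obs', dist.Normal(mu, 10), obs=y_curr)
        y_recents = jnp.concatenate([y_curr[None], y_recents[:-1]])  # shift the window
        return y_recents, y_curr

    y_init = jnp.zeros(len(b) - 1)  # one slot per lag coefficient in b[1:]
    _, ys = scan(transition, y_init, y_obs)
```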
For prediction, you can simply put the `scan` primitive under a `condition` handler (the same pattern as in the time series forecasting tutorial).
I think the latter approach is more convenient and more readable than the first one (which uses the `stack_and_shift` matrix), but it is a bit slower, because `scan` steps through the series sequentially while the matrix version is a single vectorized product (though I suspect the difference is small when the time series is short). If you are after performance, you can add if/else logic:
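```python
def model(..., forecast=False):
    if not forecast:
        # fitting: use the first approach (vectorized, with the stack_and_shift matrix)
        ...
    else:
        # forecasting: use the second approach (scan under a condition handler)
        ...
```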