You are right. If y is a NumPy array (instead of a JAX device array), then you can use numpy.foo operators (rather than their jax.numpy counterparts) to build the matrix. Because the result does not depend on any traced value, it is treated as a constant when JAX compiles the program. This way, you can build it inside the model without worrying about performance.
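A minimal sketch of that idea, assuming the matrix is a lag ("stack and shift") matrix whose row t is [1, y[t-1], ..., y[t-p]] with zeros before the series starts. lag_matrix and ar_mean are hypothetical names used only for illustration. The NumPy code runs once while JAX traces the function, so X ends up in the compiled program as a constant and only the matrix product depends on b:

```python
import jax
import jax.numpy as jnp
import numpy as np


def lag_matrix(y, p):
    # hypothetical helper: row t is [1, y[t-1], ..., y[t-p]], zeros before the series starts
    X = np.zeros((len(y), p + 1))
    X[:, 0] = 1.0
    for k in range(1, p + 1):
        X[k:, k] = y[:-k]
    return X


y = np.array([1.0, 2.0, 3.0, 4.0])  # observed series as a plain NumPy array


@jax.jit
def ar_mean(b):
    # lag_matrix runs in plain NumPy during tracing, so X is baked in as a constant;
    # len(b) works because the shape of a traced array is known at trace time
    X = lag_matrix(y, len(b) - 1)
    return jnp.asarray(X) @ b
```

For example, ar_mean(jnp.array([0.5, 2.0])) gives the AR(1) means 0.5 + 2.0 * [0, 1, 2, 3].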
How is scan called in this case?
I guess the best way is to create an AR distribution with a log_prob method and use it in your model:
numpyro.sample('obs', AR(coefs=b, noise_scale=10), obs=y_obs)
Alternatively, you can use NumPyro's scan primitive (you can mimic the GaussianHMM example in the scan documentation):
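import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(b, y_obs):  # b = [intercept, lag_1, ..., lag_p]
    def transition(y_recents, y_curr):
        mu = b[0] + b[1:] @ y_recents  # intercept + lagged terms
        y_curr = numpyro.sample('obs', dist.Normal(mu, 10), obs=y_curr)
        y_recents = jnp.concatenate([y_curr[None], y_recents[:-1]])  # newest value first
        return y_recents, y_curr
    y_init = jnp.zeros(len(b) - 1)  # one slot per lag coefficient b[1:]
    _, ys = scan(transition, y_init, y_obs)
    return ys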
For prediction, you can simply put the scan primitive under a condition handler (the same approach as in the time series forecasting tutorial).
I think the latter approach is more convenient and gives more readable code than the first approach (which uses the stack_and_shift matrix), but it is a bit slower (though I suspect the speed difference is small if the time series is short). If you are looking for performance, you can add if/else logic:
def model(..., forecast=False):
    if not forecast:
        # use the first approach (stack_and_shift matrix) for fitting
    else:
        # use the second approach (scan) for forecasting