Hello! We are working through the NumPyro tutorial "Time Series Forecasting".
We would like to use this transition function:
Is there a way to use an array of time indices instead of N in the line y_t = jnp.where(t >= N, exp_val, y[t])?
For example, we would like to choose the indices explicitly: y_t = jnp.where(t in [82, 83, …], exp_val, y[t])
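A minimal sketch of a drop-in replacement, assuming the goal is "use exp_val whenever t is one of the listed indices": jax.numpy.isin performs the membership test with JAX operations, so it works on traced values where Python's in does not. The series y, the index array forecast_idx and the constant carry standing in for exp_val are hypothetical placeholders, and plain jax.lax.scan stands in for the scan used in the tutorial's model.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical placeholders: y is the observed series, forecast_idx lists the
# time steps where the expected value should replace the observation.
y = jnp.arange(10.0)
forecast_idx = jnp.array([3, 4, 7])

def transition_fn(carry, t):
    exp_val = carry  # placeholder for the model's expected value
    # jnp.isin evaluates membership on a traced t, unlike Python's `t in [...]`
    y_t = jnp.where(jnp.isin(t, forecast_idx), exp_val, y[t])
    return carry, y_t

_, ys = lax.scan(transition_fn, jnp.asarray(-1.0), jnp.arange(10))
print(ys)
```

Any other t >= N test in the transition function would likely need the same replacement.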
However, when we use the array instead of N, we get the following error message:
You can use transformation parameters such as static_argnums for jit to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See JAX Frequently Asked Questions (FAQ) — JAX documentation for more information.
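The error most likely comes from Python's in operator rather than from jnp.where: inside jit (and inside scan) t is an abstract tracer, and a membership test on a Python list compares t with each element and then calls bool() on the traced result, which JAX cannot evaluate. The sketch below reproduces this with a hypothetical index list and shows that the same test written with JAX operations (an elementwise t == array followed by jnp.any) traces without error.

```python
import jax
import jax.numpy as jnp

idx = [3, 4, 7]  # hypothetical list of time indices

@jax.jit
def with_python_in(t):
    # `t in idx` calls bool() on a traced comparison, which fails under jit
    return jnp.where(t in idx, 1.0, 0.0)

@jax.jit
def with_jax_ops(t):
    # Compare against all indices at once and reduce, staying inside JAX
    return jnp.where(jnp.any(t == jnp.asarray(idx)), 1.0, 0.0)

try:
    with_python_in(4)
    raised = False
except jax.errors.ConcretizationTypeError:  # bool-conversion errors subclass this
    raised = True

print(raised)
print(with_jax_ops(4), with_jax_ops(5))
```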

How do we use static_argnums for this example?
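A sketch of how static_argnums could be used, under stated assumptions. static_argnums is an argument of jax.jit, so it applies to arguments of a function you jit yourself; the loop index t that scan passes to the transition function is always traced, so making the index list static does not by itself make t in [...] work. What a static argument does allow is building the mask with plain NumPy at trace time. Static arguments must be hashable, so the indices are passed as a tuple rather than a list, and every new tuple triggers a recompilation. fill_forecast is a hypothetical helper, and the constant carry stands in for exp_val. In the tutorial the model is compiled by NumPyro's MCMC rather than by a direct jax.jit call; as far as we know, with MCMC's default jit_model_args=False the model arguments are closed over as constants, which has a similar effect.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(1,))
def fill_forecast(y, forecast_idx):
    # forecast_idx is static (a hashable tuple), so it is a concrete Python
    # value during tracing and NumPy can build the mask directly.
    mask = np.zeros(y.shape[0], dtype=bool)
    mask[list(forecast_idx)] = True
    mask = jnp.asarray(mask)

    def transition_fn(carry, t):
        exp_val = carry  # placeholder for the model's expected value
        # t is still a tracer here, so look up the mask instead of using `in`
        y_t = jnp.where(mask[t], exp_val, y[t])
        return carry, y_t

    _, ys = lax.scan(transition_fn, jnp.asarray(-1.0), jnp.arange(y.shape[0]))
    return ys

ys = fill_forecast(jnp.arange(10.0), (3, 4, 7))
print(ys)
```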

Also, is there a more efficient way to solve this without using static_argnums?
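One alternative that avoids static_argnums entirely, sketched under the same placeholder assumptions: build a boolean mask once, outside the traced code, pass it in as an ordinary array argument, and look up mask[t] inside the transition function. The lookup is a single gather per step (jnp.isin compares t against every listed index), and because the mask is traced rather than static, changing which steps are selected does not trigger recompilation as long as the mask's length stays the same. make_mask and fill_forecast are hypothetical helper names. If the scan runs past the end of the observed series, as it does when forecasting, the mask must be as long as the scanned index range.

```python
import jax
import jax.numpy as jnp
from jax import lax

def make_mask(length, forecast_idx):
    # Boolean mask with True at every time step that should use exp_val
    return jnp.zeros(length, dtype=bool).at[jnp.asarray(forecast_idx)].set(True)

@jax.jit
def fill_forecast(y, mask):
    # mask is an ordinary traced argument: changing which entries are True
    # reuses the compiled function; only a new shape or dtype recompiles.
    def transition_fn(carry, t):
        exp_val = carry  # placeholder for the model's expected value
        y_t = jnp.where(mask[t], exp_val, y[t])  # one lookup per step
        return carry, y_t

    _, ys = lax.scan(transition_fn, jnp.asarray(-1.0), jnp.arange(y.shape[0]))
    return ys

y = jnp.arange(10.0)
ys_a = fill_forecast(y, make_mask(10, [3, 4, 7]))
ys_b = fill_forecast(y, make_mask(10, [0, 9]))  # same shapes, no recompilation
print(ys_a, ys_b)
```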