Using an array of time indices in a transition function in NumPyro

Hello! We are trying the tutorial: Time Series Forecasting — NumPyro documentation
We would like to use this transition function:

Is there a way to use an array instead of N in the line: y_t = jnp.where(t>=N, exp_val, y[t])?

We would like to pick specific time steps from an array, for example: y_t = jnp.where(t in [82,83,…], exp_val, y[t])

However, when we use the array above in place of N, we get an error message:
You can use transformation parameters such as static_argnums for jit to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See JAX Frequently Asked Questions (FAQ) — JAX documentation for more information.

  1. How do we use static_argnums for this example?

  2. Also, is there a more efficient way to solve this without using static_argnums?

Hi @pookie, I think t in [82,83,...] can be translated to jnp.any(t == jnp.array([82,83,...])).
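
A minimal sketch of that translation inside a scan, in case it helps. The names y, masked_steps, and transition_fn, the toy data, and the use of the carry as exp_val are all assumptions for illustration, not the tutorial's actual model. The point is only that t is a traced value inside the transition function, so Python's `in` cannot produce a concrete True/False, whereas an elementwise comparison reduced with jnp.any stays traceable:

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins: `y` plays the role of the observed series and
# `masked_steps` lists the time indices to replace (not from the tutorial).
y = jnp.arange(10.0)
masked_steps = jnp.array([2, 5, 7])

def transition_fn(carry, t):
    # `t` is traced inside scan, so `t in [...]` would fail here.
    # Comparing `t` against the whole array and reducing with `any`
    # gives a traced boolean that jnp.where accepts.
    in_set = jnp.any(t == masked_steps)
    exp_val = carry  # placeholder for the model's expected value
    y_t = jnp.where(in_set, exp_val, y[t])
    return y_t, y_t

init = jnp.array(-1.0, dtype=jnp.float32)
_, ys = lax.scan(transition_fn, init, jnp.arange(y.shape[0]))
```

In this toy setup the values at steps 2, 5, and 7 are replaced by the previous step's value, and every other step keeps its observation. No static_argnums is needed, because the index array is an ordinary constant captured by the function.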


Thank you @fehiepsi! Yes, that works!