Hi,
I am asking the NumPyro developers this question because they may have run into a JAX problem I am facing.
It concerns the case where one needs to use jacrev, which forbids the use of jax.lax.while_loop, whether directly (i.e. the user writes the while_loop) or indirectly through jax.lax.fori_loop, which internally falls back to while_loop when its loop bounds are not concrete.
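A minimal reproduction sketch, using a hypothetical toy function `f`. When the upper bound of `jax.lax.fori_loop` is a traced integer (here because `n` is an ordinary, non-static argument of a jitted function), the loop is lowered to `while_loop`, and `jax.jacrev` should fail at trace time. Making `n` static lets `fori_loop` lower to `scan`, and the same derivative then goes through.

```python
import jax
import jax.numpy as jnp

def f(x, n):
    # Hypothetical toy iteration: the carry depends on x at every step.
    def body(i, c):
        return jnp.sin(c) * x
    # If n is a tracer, fori_loop cannot use scan and falls back to while_loop.
    return jax.lax.fori_loop(0, n, body, x)

# n is traced under jit -> reverse-mode differentiation of while_loop fails.
df_traced = jax.jit(jax.jacrev(f, argnums=0))
try:
    df_traced(jnp.array(0.5), 3)
    failed = False
except Exception as err:  # a ValueError about reverse-mode and while_loop, as far as I know
    failed = True
    print(type(err).__name__, err)

# n is a static Python int -> fori_loop lowers to scan and jacrev works.
df_static = jax.jit(jax.jacrev(f, argnums=0), static_argnums=1)
ok = df_static(jnp.array(0.5), 3)
```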
It fails because in the innermost scan I have
    ... = jax.lax.scan(body, init, jnp.arange(nloops))  or  ... = jax.lax.scan(body, init, None, length=nloops)
where nloops depends on a traced value computed from the arguments of the function called by the outermost scan. In other words, everything inside the outer scan is traced automatically (as under JIT), so nloops is never a concrete Python integer, and there is no way to pass it as a static argument (e.g. via static_argnums).
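One workaround that I believe is commonly used (not something I know to be taken from NumPyro) is to give the inner scan a static upper bound and mask the iterations past the dynamic count. The sketch below assumes such a bound is known in advance; `MAX_STEPS`, `body_step`, `masked_loop` and `outer` are hypothetical names. Every inner scan runs exactly `MAX_STEPS` iterations, but `jnp.where(i < n, ...)` leaves the carry unchanged once `i >= n`, so the result (and its reverse-mode gradient) matches a loop of `n` steps while `n` itself stays traced.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 16  # hypothetical static upper bound on the trip count (must be >= every n)

def body_step(c, x):
    # Hypothetical single step of the inner iteration.
    return jnp.sin(c) * x

def masked_loop(x, n):
    # Static-length scan; updates are applied only while i < n (n may be traced).
    def step(c, i):
        c_new = body_step(c, x)
        return jnp.where(i < n, c_new, c), None
    c_final, _ = jax.lax.scan(step, x, jnp.arange(MAX_STEPS))
    return c_final

def outer(xs, ns):
    # Outer scan: each trip count ns[k] is a tracer, as in the question.
    def step(acc, xn):
        x, n = xn
        return acc + masked_loop(x, n), None
    total, _ = jax.lax.scan(step, jnp.float32(0.0), (xs, ns))
    return total

grad_fn = jax.jit(jax.jacrev(outer))  # differentiates w.r.t. xs only
xs = jnp.array([0.5, 0.7, 0.9], dtype=jnp.float32)
ns = jnp.array([1, 3, 5], dtype=jnp.int32)
g = grad_fn(xs, ns)
```

The price is that every inner loop always costs `MAX_STEPS` iterations, and the gradient through the masked `jnp.where` branch is zero only as long as `body_step` stays finite on the discarded iterations.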
I was wondering whether you have come across such a use case while developing NumPyro, and if so, whether you found a workaround?
Thanks