Nested jax.lax.scan loops

I am asking the NumPyro developers this question because they may have run into a JAX problem I am facing.

It concerns the case where one needs jacrev, which forbids jax.lax.while_loop, whether it is used directly (i.e., written by the user) or indirectly through jax.lax.fori_loop, which internally falls back to while_loop when its trip count is not known at trace time.
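A minimal sketch of this restriction (f_static and f_dynamic are hypothetical illustrations, not code from the original use case): with Python-int bounds, jax.lax.fori_loop is lowered to scan and jax.jacrev works; when the upper bound is a tracer (forced here with jax.jit), fori_loop falls back to while_loop and reverse-mode differentiation raises an error.

```python
import jax
import jax.numpy as jnp

def f_static(x):
    # Python-int bounds: fori_loop is lowered to scan, so reverse mode works.
    return jax.lax.fori_loop(0, 5, lambda i, acc: acc * x, jnp.ones_like(x))

def f_dynamic(x, n):
    # Under jit, n is a tracer, so fori_loop falls back to while_loop.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))

x = jnp.float32(2.0)
grad_static = jax.jacrev(f_static)(x)
print(grad_static)  # d/dx x**5 at x = 2 is 5 * 2**4 = 80

try:
    jax.jacrev(jax.jit(f_dynamic))(x, 5)
    dynamic_failed = False
except Exception as err:  # reverse mode through while_loop is not supported
    dynamic_failed = True
    print(type(err).__name__, err)
```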

It fails because the innermost scan looks like

carry, ys = jax.lax.scan(body, init, jnp.arange(nloops))   or   carry, ys = jax.lax.scan(body, init, None, length=nloops)

where nloops depends on a traced value computed from the arguments of the function called by the outermost scan. In other words, the scan body is traced automatically, as under jit, so there is no way to mark nloops as a static argument (static_argnums) in the process.
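A small reproducer along these lines, assuming (hypothetically) that the inner trip count is floor(x) for each element x scanned by the outer loop; inner and outer are illustrative names, not the code from the original use case. Calling outer fails while tracing the inner scan, because the `length` argument of jax.lax.scan must be a concrete Python int.

```python
import jax
import jax.numpy as jnp

def inner(x, nloops):
    def body(carry, _):
        return carry * x, None
    # `length` must be a concrete Python int; here nloops is a tracer.
    carry, _ = jax.lax.scan(body, jnp.ones_like(x), None, length=nloops)
    return carry

def outer(xs):
    def body(acc, x):
        # The inner trip count depends on the value scanned by the outer loop.
        nloops = jnp.floor(x).astype(jnp.int32)
        return acc + inner(x, nloops), None
    total, _ = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return total

try:
    outer(jnp.array([2.0, 3.0]))
    scan_failed = False
except Exception as err:  # raised because `length` cannot be a tracer
    scan_failed = True
    print(type(err).__name__)
```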

I was wondering whether you have encountered such a use case in NumPyro development and, if so, whether you found a workaround.

Could you add some reproducible code? It would help us understand the issue you have.

Well, @fehiepsi, it is not related to NumPyro, but if you are interested I have posted a use case in the JAX discussion forum HERE. It comes from porting C++ code to JAX where, at the end of the day, I need jacrev to work, so while_loop is forbidden, whether used directly or through fori_loop, which falls back to while_loop. If you have a nice workaround, you will be my hero.
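One workaround that is often used in this situation, sketched here as an assumption rather than as anything from NumPyro itself: if an upper bound on nloops is known statically (MAX_STEPS below is a hypothetical constant), run the inner scan for that fixed number of steps and mask the extra iterations with jnp.where. Reverse mode then only ever sees scan, so jacrev works through the nested loops. The trip count rule floor(x) is the same hypothetical choice as in the reproducer above.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 8  # hypothetical static upper bound on nloops

def inner(x, nloops):
    # Always run MAX_STEPS iterations; steps with i >= nloops keep the carry unchanged.
    def body(carry, i):
        updated = carry * x
        return jnp.where(i < nloops, updated, carry), None
    carry, _ = jax.lax.scan(body, jnp.ones_like(x), jnp.arange(MAX_STEPS))
    return carry

def outer(xs):
    def body(acc, x):
        nloops = jnp.floor(x).astype(jnp.int32)  # traced, data-dependent trip count
        return acc + inner(x, nloops), None
    total, _ = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return total

xs = jnp.array([2.0, 3.0])
print(outer(xs))              # 2**2 + 3**3 = 31
print(jax.jacrev(outer)(xs))  # [2 * 2**1, 3 * 3**2] = [4, 27]
```

The price is that every call pays for MAX_STEPS iterations, and no gradient flows through nloops itself (floor has a zero derivative anyway). If the masked-out update can produce inf or NaN, the gradient of jnp.where can still become NaN, so the masked branch may need to be made safe as well.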