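I think the current provenance approach does not work with JAX control flow. I think you can use

```python
return jnp.where(x < 0, R0 + v*x, R0 + v*x - k*(1. - jnp.exp(-x/tau)))
```

instead. `jnp.where` evaluates both branches elementwise and selects the result, so no Python-level branching on traced values is needed. Here is a reference for plate notation: Plate notation - Wikipedia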
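A minimal sketch of why the `jnp.where` form works under JAX transformations. It assumes the suggested line sits in a function of `x` with scalar parameters. The function names `mean_fn` and `mean_fn_python_if` and the default values of `R0`, `v`, `k`, `tau` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical defaults: R0, v, k, tau match the names in the one-liner above,
# but their real values are not given in this thread.
def mean_fn(x, R0=1.0, v=0.5, k=2.0, tau=3.0):
    # jnp.where evaluates both branches elementwise and selects per element,
    # so nothing branches on a traced value at the Python level.
    return jnp.where(x < 0,
                     R0 + v * x,
                     R0 + v * x - k * (1. - jnp.exp(-x / tau)))

def mean_fn_python_if(x, R0=1.0, v=0.5, k=2.0, tau=3.0):
    # Hypothetical Python-`if` version for comparison: converting a traced
    # comparison to a Python bool fails under jax.jit.
    if x < 0:
        return R0 + v * x
    return R0 + v * x - k * (1. - jnp.exp(-x / tau))

xs = jnp.array([-2.0, 0.0, 2.0])
out = jax.jit(mean_fn)(xs)       # works under jit on a whole array of inputs
slope = jax.grad(mean_fn)(-1.0)  # also differentiable; equals v on the x < 0 side
print(out, slope)

try:
    jax.jit(mean_fn_python_if)(jnp.float32(1.0))
except jax.errors.ConcretizationTypeError as err:
    print("Python if fails under jit:", type(err).__name__)
```

One caveat with this pattern: because both branches are computed, a very negative `x` can make `jnp.exp(-x/tau)` in the unused branch overflow, and the resulting `inf` can turn gradients into NaN. If that range matters, clamping the argument inside that branch is a common workaround.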