RL model with numpyro. Assignment issue

Hi, I am new to JAX and am working on parameter inference for an RL model with NumPyro.

The framework is roughly as below:

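import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(observations, a=None, b=None):
    # priors: Normal(loc, scale), where scale must be positive
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))

    # RL model (RLModel is my own class, not shown here)
    rl = RLModel(a, b)
    rl.value_iteration()
    simulations = rl.simulate()
    mu = numpyro.deterministic("mu", jnp.mean(simulations))

    # likelihood
    numpyro.sample("obs", dist.Normal(mu, 1), obs=observations)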

I want to run a value iteration function to get the optimal policy, and then use that policy for simulation. But inside the model the input parameters (here a and b) are tracers, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, and a tracer cannot be assigned to an element of a NumPy array. For example:

def foo(a, b):
    v = np.zeros((10, 10))
    v[0, 0] = a + b
    return v

I got the error:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function get_trace at /home/w/miniconda3/envs/abcd_env/lib/python3.9/site-packages/numpyro/infer/inspect.py:304 for jit. This value became a tracer due to JAX operations on these lines:

I would greatly appreciate any advice on the assignment issue, or references to learn more about it. Thanks a lot!
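
A possible way around this, sketched under assumptions: JAX arrays are immutable, so an in-place write like v[0, 0] = a + b has to become a functional update, v.at[0, 0].set(a + b), on a jnp array rather than a NumPy one. The same idea extends to value iteration if the loop is written with jax.lax.fori_loop. In the sketch below, value_iteration, the (S, A, S) layout of P, the (S, A) layout of R, gamma, and n_iters are all hypothetical and not taken from the RLModel above.

```python
import jax
import jax.numpy as jnp
from jax import lax


def foo(a, b):
    # jnp arrays are immutable: .at[...].set(...) returns an updated copy
    # and accepts traced values such as a + b.
    v = jnp.zeros((10, 10))
    v = v.at[0, 0].set(a + b)
    return v


def value_iteration(P, R, gamma=0.9, n_iters=100):
    """Tabular value iteration written with JAX operations only.

    Assumed (hypothetical) layouts:
      P: (S, A, S) transition probabilities, P[s, a, t] = Pr(t | s, a)
      R: (S, A) expected immediate rewards
    """
    def q_values(v):
        # Q(s, a) = R(s, a) + gamma * sum_t P(s, a, t) * V(t)
        return R + gamma * jnp.einsum("sat,t->sa", P, v)

    def body(_, v):
        # One Bellman backup: V(s) = max_a Q(s, a)
        return jnp.max(q_values(v), axis=1)

    v0 = jnp.zeros(R.shape[0])
    v = lax.fori_loop(0, n_iters, body, v0)
    policy = jnp.argmax(q_values(v), axis=1)
    return v, policy


# foo now works with traced inputs, for example under jit.
print(jax.jit(foo)(1.0, 2.0)[0, 0])

# Toy two-state MDP: action 0 stays, action 1 switches state;
# only staying in state 1 is rewarded.
P = jnp.zeros((2, 2, 2))
P = P.at[0, 0, 0].set(1.0).at[0, 1, 1].set(1.0)
P = P.at[1, 0, 1].set(1.0).at[1, 1, 0].set(1.0)
R = jnp.zeros((2, 2)).at[1, 0].set(1.0)
values, policy = value_iteration(P, R)
print(values, policy)
```

Because the iteration count here is a fixed Python integer, the loop should stay compatible with jit and with reverse-mode differentiation, which gradient-based samplers such as NUTS rely on.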