Hi, I am a beginner to JAX and am working on parameter inference for an RL model with NumPyro.
The framework is roughly as below:
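```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(observations):
    # priors (a Normal scale must be positive, so Normal(1, 0) is invalid)
    a = numpyro.sample("a", dist.Normal(1.0, 1.0))
    b = numpyro.sample("b", dist.Normal(1.0, 1.0))
    # RL model (RLModel is my own class); a and b are JAX tracers here
    rl = RLModel(a, b)
    rl.value_iteration()
    simulations = rl.simulate()
    mu = numpyro.deterministic("mu", jnp.mean(simulations))
    # likelihood
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=observations)
```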
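One possible way to make the `value_iteration` and `simulate` steps accept traced `a` and `b` is to write them as pure functions built from `jax.numpy` and `jax.lax` control flow, with no in-place NumPy writes. This is a minimal sketch under loudly stated assumptions: the 3-state, 2-action MDP, the transition tensor `P`, the `rewards` function, `GAMMA`, `N_ITERS` and the fixed iteration count are all hypothetical and are not the original `RLModel`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy MDP (not from the original post): 3 states, 2 actions.
# P[s, u, s'] is the transition probability; each row sums to 1.
P = jnp.array([
    [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1]],
    [[0.1, 0.8, 0.1], [0.8, 0.1, 0.1]],
    [[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
])
GAMMA = 0.9     # hypothetical discount factor
N_ITERS = 200   # hypothetical fixed number of Bellman iterations

def rewards(a, b):
    # R[s, u]: hypothetical reward table built from a and b without in-place writes
    base = jnp.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    return a * base + b * jnp.eye(3, 2)

def value_iteration(a, b):
    R = rewards(a, b)

    def bellman(_, V):
        # Q[s, u] = R[s, u] + gamma * sum_s' P[s, u, s'] * V[s']
        Q = R + GAMMA * jnp.einsum("sut,t->su", P, V)
        return Q.max(axis=1)

    # fori_loop with static Python-int bounds traces the body once (so a and b
    # may be tracers) and lowers to scan, which supports reverse-mode autodiff
    V = jax.lax.fori_loop(0, N_ITERS, bellman, jnp.zeros(3))
    Q = R + GAMMA * jnp.einsum("sut,t->su", P, V)
    policy = jnp.argmax(Q, axis=1)  # greedy action for each state
    return V, policy

def simulate(policy, key, s0=0, n_steps=50):
    # Roll out the greedy policy; lax.scan replaces a Python for-loop
    def step(s, k):
        u = policy[s]
        s_next = jax.random.categorical(k, jnp.log(P[s, u])).astype(jnp.int32)
        return s_next, s_next

    keys = jax.random.split(key, n_steps)
    _, states = jax.lax.scan(step, jnp.int32(s0), keys)
    return states

# Usage: value iteration runs under jit with traced a and b
V, policy = jax.jit(value_iteration)(1.0, 0.5)
states = simulate(policy, jax.random.PRNGKey(0))
```

A caveat worth considering: `jnp.argmax` and the discrete `jax.random.categorical` draws carry no gradient with respect to `a` and `b`, so a gradient-based sampler such as NUTS would see zero gradient through `mu` along this path. A differentiable summary (for example, expected quantities under a softmax policy instead of sampled trajectories) or a gradient-free inference method may be needed.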
I want to run a value iteration function to get the optimal policy, and then use that policy for simulation. However, inside the model the input parameters (here `a` and `b`) are tracers, `Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>`, rather than concrete numbers, so they cannot be used as arguments to functions that expect plain values: for example, an element of a NumPy array cannot be assigned a tracer. For example:
```python
import numpy as np

def foo(a, b):
    v = np.zeros((10, 10))
    v[0, 0] = a + b  # fails when a + b is a JAX tracer
    return v
```
I got the error:

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
The error occurred while tracing the function get_trace at /home/w/miniconda3/envs/abcd_env/lib/python3.9/site-packages/numpyro/infer/inspect.py:304 for jit. This value became a tracer due to JAX operations on these lines:
```
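The assignment in `foo` can be expressed with JAX's functional update syntax: JAX arrays are immutable, and `x.at[idx].set(y)` returns a new array with that element replaced, which also works when `y` is a tracer. A minimal sketch keeping the shape from `foo`:

```python
import jax
import jax.numpy as jnp

def foo(a, b):
    v = jnp.zeros((10, 10))
    # Functional update: returns a copy with v[0, 0] = a + b, no in-place write
    v = v.at[0, 0].set(a + b)
    return v

# Under jit, a and b are abstract tracers during tracing; this still works.
v = jax.jit(foo)(1.0, 2.0)
print(v[0, 0])  # 3.0

# The update is also differentiable, which gradient-based samplers need.
g = jax.grad(lambda a: foo(a, 2.0)[0, 0])(1.0)
print(g)  # 1.0
```

Each `.at[...].set(...)` conceptually produces a fresh array, but inside `jax.jit` XLA can often perform such updates in place, so the copy is usually not a performance concern.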
I would greatly appreciate any advice on the assignment issue, or references where I can learn more about it. Thanks a lot!