Fitting artificial agents to behavioral data

Hello fellow numpyro users,

I am trying to fit reinforcement learning agents to behavioral data, and I was wondering whether it is somehow possible to use environments from, e.g., OpenAI Gym within a model. I was also wondering whether there is a way to run functions containing if statements on dynamically traced jnp arrays (e.g. sampled actions).

So far I’ve seen that external functions that take sampled, traced jnp arrays as arguments have to be compilable with jit. I have tried using the @partial(jit, static_argnums=(1,)) decorator, but this leads to errors because the argument of the function is not hashable. Is there a way around this?
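A sketch of why the static_argnums route is limited: it makes jit treat the argument as a compile-time Python constant, so the argument must be hashable and its concrete value must be known at call time. Plain Python ints qualify, jax arrays do not, and a sampled action inside a traced model has no concrete value at that point. The function environment_step_static is a hypothetical illustration, not code from the thread:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: argument 0 is marked static, so inside the function
# `action` is an ordinary Python value and a plain `if` is allowed.
@partial(jax.jit, static_argnums=(0,))
def environment_step_static(action):
    if action == 1:
        return jnp.float32(1.0)
    else:
        return jnp.float32(-1.0)


print(environment_step_static(1))  # works: Python ints are hashable

try:
    environment_step_static(jnp.array(1))  # jax arrays are not hashable
except (TypeError, ValueError) as err:
    print("static argument rejected:", type(err).__name__)
```

Each new static value also triggers a recompilation, so even where it works this approach only suits a handful of fixed values, not actions sampled while the model is being traced.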

A simple example would be a case in which I have a sampled action index inside a traced jnp.array and I want to pass it to an environment that returns a value depending on that action index:

action = numpyro.sample(f"obs{i}_{j}", dist.CategoricalLogits(q_values), obs=data[i,j])
reward = environment_step(action)


def environment_step(action):
    if action == jnp.array(1.):
        return 1.0
    else:
        return -1.0
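For a toy environment like this one, the traced if can be rewritten with jnp.where or jax.lax.cond, both of which compile under jit because the branch is selected from a traced boolean instead of by Python control flow. A sketch (the function names are illustrative, not part of any library):

```python
import jax
import jax.numpy as jnp


@jax.jit
def environment_step_where(action):
    # jnp.where picks between the two rewards elementwise, so it works on
    # traced values and on whole batches of actions.
    return jnp.where(action == 1, 1.0, -1.0)


@jax.jit
def environment_step_cond(action):
    # lax.cond selects which branch to run from a traced boolean scalar.
    return jax.lax.cond(action == 1, lambda: 1.0, lambda: -1.0)
```

jnp.where evaluates both branches and handles arrays of actions; lax.cond needs a scalar predicate. Neither helps with a genuinely opaque environment such as a Gym env, since that would still have to be rewritten in JAX.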

Is there a way to make something like this work? Ideally, I would also like to be able to pass a sampled action to any external black-box function (much as, in deep reinforcement learning, the environment is opaque to the agent and is not part of the gradient descent used to train it). Is that possible without having to rewrite the environment so that it is fully jit-compilable with JAX?

Hi @smp, if you only need jit compilation (i.e. no grad, vmap, or pmap transformations), then you might want to use host_callback. Some of its examples in numpyro:
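A minimal sketch of that idea, with one substitution: host_callback is deprecated in recent JAX releases, so this uses jax.pure_callback, which plays the same role of calling ordinary host-side Python from jit-compiled code. environment_step_host is a hypothetical stand-in for a black-box environment:

```python
import numpy as np
import jax
import jax.numpy as jnp


def environment_step_host(action):
    # Hypothetical black-box environment. It runs in ordinary Python on the
    # host and receives a concrete NumPy value, so a plain `if` is allowed.
    if int(action) == 1:
        return np.array(1.0, dtype=np.float32)
    else:
        return np.array(-1.0, dtype=np.float32)


@jax.jit
def step(action):
    # Declare the shape/dtype of the callback's result (a float32 scalar)
    # so that JAX can trace the surrounding computation.
    result_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(environment_step_host, result_spec, action)


reward = step(jnp.array(1))
```

jax.pure_callback has no differentiation rule, so this is only appropriate when the callback's inputs do not depend on the parameters being differentiated, as with observed actions.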

Thank you very much for the hint @fehiepsi. I have now wrapped the environment_step function with host_callback, and it returns an empty array that prints as Traced<ShapedArray(float32[]):JaxprTrace(level=2/0)> and has a length of 0. For the result_shape argument, I used jax.ShapeDtypeStruct((), jnp.float32).

I’ve tried changing the return values of environment_step to match the type specified in result_shape, but the returned array was still empty. There is probably something I’m still missing about the return values and data types.
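For what it is worth, float32[] in that printout denotes an array of shape (), i.e. a scalar, which is exactly what jax.ShapeDtypeStruct((), jnp.float32) declares. Inside jit-compiled code (and NumPyro compiles the model during inference), intermediate values print as abstract tracers; the concrete number only appears once the compiled function returns. A small illustration, where double_and_show is an arbitrary name:

```python
import jax
import jax.numpy as jnp


@jax.jit
def double_and_show(x):
    # At trace time x is an abstract tracer; only its shape and dtype are known.
    # For a scalar input this prints "() 0": shape () with zero dimensions.
    print(x.shape, x.ndim)
    return x * 2.0


# The concrete value is available only after the compiled function returns.
result = double_and_show(jnp.float32(3.0))
print(result)
```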

edit: It seems that the rest of the code is parallelized, while the function called through host_callback now returns a single value. Is there another way to use functions containing ordinary Python if statements with dynamically traced arguments?
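One way to handle that shape mismatch, sketched under the assumption that the host function can receive the whole batch of actions at once: declare the callback's output with the same shape as its input and loop over the batch on the host. environment_step_batch is hypothetical, and jax.pure_callback again stands in for host_callback:

```python
import numpy as np
import jax
import jax.numpy as jnp


def environment_step_batch(actions):
    # Host-side: `actions` is a concrete NumPy array of any shape.
    rewards = np.empty(actions.shape, dtype=np.float32)
    for idx, a in np.ndenumerate(actions):
        # Arbitrary Python control flow is allowed here.
        if a == 1:
            rewards[idx] = 1.0
        else:
            rewards[idx] = -1.0
    return rewards


@jax.jit
def step_batch(actions):
    # The declared output has the same shape as the batch of actions,
    # so one reward comes back per action instead of a single value.
    out_spec = jax.ShapeDtypeStruct(actions.shape, jnp.float32)
    return jax.pure_callback(environment_step_batch, out_spec, actions)
```

Passing the full array of actions (e.g. all trials and participants) in one call also keeps the number of host round-trips small.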

Hi @smp,

I have been using both Pyro and Numpyro to fit computational models of behavior to experimental data. Note that you do not need to sample actions to estimate the parameters, as you are only using the actions to compute the model likelihood. In your case, both the actions and their outcomes are fixed by the participants’ behavior, so it is sufficient to have a mapping from outcomes to action probabilities (logits).

For sampling actions from a model, you can write a separate Distribution (BehaviouralDistribution) and define the sampling process inside either its sample method or another dedicated method that returns both actions and outcomes.

I do not see an advantage to running the environment inside the generative model you are using for parameter estimation. It will just slow down inference.