UnexpectedTracerError because of numpyro.sample call within a function

Dear fellow numpyro users (and developers),

I’m trying to use a jax.lax.fori_loop to make sampling actions for a number of agents faster and, more importantly, to speed up compilation (the Python kernel always crashed after about 10 minutes of compiling when I used plain Python for loops).

Is there a way to sample (with the obs argument set) inside a fori_loop? The motivation is that this sampling has to be done for many time-steps and many agents, i.e. over two nested for loops. My simulator code works fine when I run it outside of numpyro (jit-compiled and all). When I run it inside a numpyro model (with MCMC), however, it throws an UnexpectedTracerError for the numpyro.sample call in the fori_loop, because the sampled actions (tracers) leak out of the loop (which makes sense and is probably unavoidable with this setup). Here is the code of the function I’m trying to call inside the numpyro model:

import jax.numpy as jnp
from jax import jit, lax
import numpyro
import numpyro.distributions as dist


@jit
def sample_next_actions(time_step, q_values, data):

    # one action per agent (data is indexed as data[time_step, agent])
    actions_t = jnp.empty(data.shape[1])

    # i as in the i-th agent's action to be sampled at this time-step
    def body_fn(i, actions_t):
        # Problem: numpyro leaks the sampled value (a tracer) out of the loop.
        action = numpyro.sample(f"action{time_step}{i}",
                                dist.Categorical(logits=q_values[i, :]),
                                obs=data[time_step, i])
        return actions_t.at[i].set(action)

    actions_t = lax.fori_loop(0, actions_t.size, body_fn, actions_t)

    return actions_t

And here is my numpyro model (with which I want to estimate the agents' learning rate distributions), where this function is called inside another fori_loop:

def model(N_agents, N_decisions, data):

    # one learning rate per agent
    with numpyro.plate("learning_rates", N_agents):
        learning_rates = numpyro.sample("learning_rate", dist.Normal(loc=0., scale=1.))

    # actions array
    actions = jnp.zeros((N_agents, N_decisions))

    # initial Q-values
    q_values = jnp.zeros((N_agents, 2))

    initial_val = (actions, q_values)

    # i as in the i-th decision/time-step for all agents
    def body_fn(i, vals):
        actions, q_values = vals

        new_actions = sample_next_actions(i, q_values, data)
        rewards = get_rewards(new_actions)
        q_values = update_q_vals(q_values, new_actions, learning_rates, rewards)

        return (actions.at[:, i].set(new_actions), q_values)

    sampled_actions, q_vals = lax.fori_loop(0, N_decisions, body_fn, initial_val)
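
For completeness, this is roughly how I invoke the model; a minimal sketch with illustrative NUTS/MCMC settings (the exact configuration shouldn't matter for the error):

import jax
from numpyro.infer import MCMC, NUTS

# illustrative inference setup; the settings here are placeholders
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), N_agents, N_decisions, data)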

Is there a way to draw many samples simultaneously (one per agent) within an efficient for loop over many time-steps? If not, is there another way to speed up compilation and keep the kernel from crashing? Everything worked fine for a few agents and a few decision steps, but compilation time seemed to increase exponentially with the number of agents/time-steps.

There is indeed a way to do this, and it is pretty straightforward. Instead of using the numpyro.sample primitive inside the for loop, I now sample normally and at the same time compute the log-probabilities for the sequence and store them in an array. Then, outside of the for loop, I add the sum of the log-probabilities for each sequence to the model using the numpyro.factor primitive. Both compilation and sampling are much quicker now, and it turned out to be a lot easier than I had imagined.
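
Here is a minimal sketch of the idea (not my exact code), using the same kind of imports as above. I assume data[t, i] holds the integer-coded observed action of agent i at time-step t, and get_rewards / update_q_vals are the same helper functions as in my first post; instead of storing the per-step log-probabilities in an array, this version accumulates their sum directly in the loop carry, which amounts to the same thing:

import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist


def model(N_agents, N_decisions, data):

    # one learning rate per agent
    with numpyro.plate("learning_rates", N_agents):
        learning_rates = numpyro.sample("learning_rate", dist.Normal(loc=0., scale=1.))

    # loop carry: current Q-values and the running log-probability of the data
    q_values = jnp.zeros((N_agents, 2))
    initial_val = (q_values, jnp.zeros(()))

    def body_fn(t, vals):
        q_values, log_prob = vals

        # observed actions of all agents at time-step t
        actions_t = data[t, :]

        # log-probability of the observed actions under the current Q-values,
        # computed for all agents at once (no numpyro.sample inside the loop)
        log_prob = log_prob + jnp.sum(
            dist.Categorical(logits=q_values).log_prob(actions_t)
        )

        rewards = get_rewards(actions_t)
        q_values = update_q_vals(q_values, actions_t, learning_rates, rewards)

        return (q_values, log_prob)

    _, total_log_prob = lax.fori_loop(0, N_decisions, body_fn, initial_val)

    # add the accumulated likelihood to the model's joint density
    numpyro.factor("action_log_prob", total_log_prob)

Since the actions are observed, nothing actually has to be drawn inside the loop; the Categorical log-probabilities of the observed actions are enough, and numpyro.factor adds them to the joint density just like an observed sample site would.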
