Unexpected tracer error involving PRNGKey when using block handler on inference model

Hello,

I am currently working on a project where I am running into an unusual error involving the block handler and seeding in NumPyro. Below is a simplified version of my code that reproduces the error, followed by the error message.

My code may be easier to read, but here is an overview of what I am trying to achieve:

I am trying to perform inference using a two-step MCMC approach. First, I generate good samples of beta and intercept by running MCMC on prior_model, conditioned on the covariate and pred data.

Next, I run MCMC again, this time on posterior_model, to infer the covariate. I do this by conditioning on only the “pred” data and randomly drawing beta and intercept from the prior_model samples.

The issue is that randomly sampling from those prior samples is harder than I expected. I want the indices to be random during MCMC, so beta_ix has to keep changing even after JAX compiles the model. My solution so far is to create another variable, beta_ix, with a categorical distribution that indexes into those samples randomly. I then apply the block handler to that variable, since I don’t want to perform inference on it. The problem is that this causes several strange effects with the seeding, and I’m fairly sure the tracer error has to do with the key, because the shape and dtype of the leaked tracer ((2,) and uint32) are the same as those of a PRNGKey.

Thank you in advance! Please let me know if I can clarify anything. Any ideas are also welcome; I’d love to know if there is a much easier way of doing this.

import funsor
import jax
from jax import random
import jax.numpy as np
import matplotlib.pyplot as plt
import matplotlib
import numpy as onp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import trace, seed, condition, block
from numpyro.infer import MCMC, NUTS, Predictive

numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

# Prior Model for fitting the coefficients
def prior_model(num_samples):
    #defining coefficients in model
    intercept = numpyro.sample('intercept', dist.Normal(0,1))
    beta = numpyro.sample('beta', dist.Normal(0,1))
    
    with numpyro.plate('data', num_samples):
        covariate = numpyro.sample('covariate', dist.Normal(0,3))
        logpi = intercept + covariate*beta
        pred = numpyro.sample('pred', dist.Bernoulli(logits=logpi))

    return intercept, covariate, beta, pred

#posterior model for performing inference on the "covariate" random variable 
def posterior_model(num_samples,post_samples):
    
    #defining beta_ix 
    #purpose is to index into the posterior samples randomly during MCMC
    num_choices = len(post_samples['beta'])
    probs = np.ones(num_choices)/num_choices 
    #categorical distribution where probability of drawing a random sample is the same for all prior model samples
    beta_ix = numpyro.sample('beta_ix',dist.Categorical(probs=probs)).astype(np.int32)

    intercept = post_samples['intercept'][beta_ix]
    beta = post_samples['beta'][beta_ix]

    with numpyro.plate('data', num_samples):
        covariate = numpyro.sample('covariate', dist.Normal(0,1))
        logpi = intercept + covariate*beta
        pred = numpyro.sample('pred', dist.Bernoulli(logits=logpi))

    return intercept, covariate, beta, pred, beta_ix

#fake data
num_samples = 10
C = 1.3 + 0.5 * onp.random.randn(num_samples)
P = onp.random.randint(0,2, (num_samples,))

#observing covariate and pred and running MCMC for fitted beta and intercept
model_cond = numpyro.handlers.condition(prior_model, data={'covariate': C, 'pred': P})
nuts_kernel = NUTS(model_cond)
mcmc = MCMC(nuts_kernel, num_samples=5, num_warmup=100, num_chains=4)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, num_samples)

prior_samples=mcmc.get_samples()
prior_samples.keys()

rng_key = random.PRNGKey(0)
data = {'pred':P}

#seed posterior model
seeded = seed(posterior_model, rng_key)   

#blocking beta_ix so that the model does not try to perform inference on beta_ix
pp_model_blocked = block(seeded, hide=['beta_ix'])

#condition posterior model
model_cond_post = numpyro.handlers.condition(
  pp_model_blocked, 
  data=data
  )

#run MCMC
nuts_kernel_post = NUTS(model_cond_post)
mcmc_post = MCMC(nuts_kernel_post, num_samples=15, num_warmup=100, num_chains=4)
mcmc_post.run(rng_key, num_samples, prior_samples)
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
Input In [9], in <cell line: 19>()
     17 nuts_kernel_post = NUTS(model_cond_post)
     18 mcmc_post = MCMC(nuts_kernel_post, num_samples=15, num_warmup=100, num_chains=4)
---> 19 mcmc_post.run(rng_key, num_samples, prior_samples)
     20 mcmc_post.run(jax.random.PRNGKey(1), num_samples, prior_samples)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:404, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    398 collection_size = self._collection_params["collection_size"]
    399 collection_size = (
    400     collection_size
    401     if collection_size is None
    402     else collection_size // self.thinning
    403 )
--> 404 collect_vals = fori_collect(
    405     lower_idx,
    406     upper_idx,
    407     sample_fn,
    408     init_val,
    409     transform=_collect_fn(collect_fields),
    410     progbar=self.progress_bar,
    411     return_last_val=True,
    412     thinning=self.thinning,
    413     collection_size=collection_size,
    414     progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    415     diagnostics_fn=diagnostics,
    416     num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    417 )
    418 states, last_val = collect_vals
    419 # Get first argument of type `HMCState`

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:344, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    342     progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
    343     _body_fn_pbar = progress_bar_fori_loop(_body_fn)
--> 344     last_val, collection, _, _ = fori_loop(
    345         0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
    346     )
    347 else:
    348     diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:141, in fori_loop(lower, upper, body_fun, init_val)
    139     return val
    140 else:
--> 141     return lax.fori_loop(lower, upper, body_fun, init_val)

    [... skipping hidden 16 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:248, in progress_bar_factory.<locals>.progress_bar_fori_loop.<locals>.wrapper_progress_bar(i, vals)
    247 def wrapper_progress_bar(i, vals):
--> 248     result = func(i, vals)
    249     _update_progress_bar(i + 1)
    250     return result

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:323, in fori_collect.<locals>._body_fn(i, vals)
    320 @cached_by(fori_collect, body_fun, transform)
    321 def _body_fn(i, vals):
    322     val, collection, start_idx, thinning = vals
--> 323     val = body_fun(val)
    324     idx = (i - start_idx) // thinning
    325     collection = cond(
    326         idx >= 0,
    327         collection,
   (...)
    330         identity,
    331     )

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:172, in _sample_fn_nojit_args(state, sampler, args, kwargs)
    170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    171     # state is a tuple of size 1 - containing HMCState
--> 172     return (sampler.sample(state[0], args, kwargs),)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:771, in HMC.sample(self, state, model_args, model_kwargs)
    761 def sample(self, state, model_args, model_kwargs):
    762     """
    763     Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting
    764     :data:`~numpyro.infer.hmc.HMCState`.
   (...)
    769     :return: Next `state` after running HMC.
    770     """
--> 771     return self._sample_fn(state, model_args, model_kwargs)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:467, in hmc.<locals>.sample_kernel(hmc_state, model_args, model_kwargs)
    463 else:
    464     hmc_length_args = (
    465         jnp.where(hmc_state.i < wa_steps, max_treedepth[0], max_treedepth[1]),
    466     )
--> 467 vv_state, energy, num_steps, accept_prob, diverging = _next(
    468     hmc_state.adapt_state.step_size,
    469     hmc_state.adapt_state.inverse_mass_matrix,
    470     vv_state,
    471     model_args,
    472     model_kwargs,
    473     rng_key_transition,
    474     *hmc_length_args,
    475 )
    476 # not update adapt_state after warmup phase
    477 adapt_state = cond(
    478     hmc_state.i < wa_steps,
    479     (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
   (...)
    482     identity,
    483 )

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:407, in hmc.<locals>._nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, max_treedepth_current)
    404     pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    405     _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 407 binary_tree = build_tree(
    408     vv_update,
    409     kinetic_fn,
    410     vv_state,
    411     inverse_mass_matrix,
    412     step_size,
    413     rng_key,
    414     max_delta_energy=max_delta_energy,
    415     max_tree_depth=(max_treedepth_current, max(max_treedepth)),
    416 )
    417 accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
    418 num_steps = binary_tree.num_proposals

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1177, in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
   1174     return tree, key
   1176 state = (tree, rng_key)
-> 1177 tree, _ = while_loop(_cond_fn, _body_fn, state)
   1178 return tree

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
    129     return val
    130 else:
--> 131     return lax.while_loop(cond_fun, body_fun, init_val)

    [... skipping hidden 13 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1161, in build_tree.<locals>._body_fn(state)
   1159 key, direction_key, doubling_key = random.split(key, 3)
   1160 going_right = random.bernoulli(direction_key)
-> 1161 tree = _double_tree(
   1162     tree,
   1163     verlet_update,
   1164     kinetic_fn,
   1165     inverse_mass_matrix,
   1166     step_size,
   1167     going_right,
   1168     doubling_key,
   1169     energy_current,
   1170     max_delta_energy,
   1171     r_ckpts,
   1172     r_sum_ckpts,
   1173 )
   1174 return tree, key

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:916, in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    901 def _double_tree(
    902     current_tree,
    903     vv_update,
   (...)
    912     r_sum_ckpts,
    913 ):
    914     key, transition_key = random.split(rng_key)
--> 916     new_tree = _iterative_build_subtree(
    917         current_tree,
    918         vv_update,
    919         kinetic_fn,
    920         inverse_mass_matrix,
    921         step_size,
    922         going_right,
    923         key,
    924         energy_current,
    925         max_delta_energy,
    926         r_ckpts,
    927         r_sum_ckpts,
    928     )
    930     return _combine_tree(
    931         current_tree, new_tree, inverse_mass_matrix, going_right, transition_key, True
    932     )

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1061, in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
   1057     return new_tree, turning, r_ckpts, r_sum_ckpts, rng_key
   1059 basetree = prototype_tree._replace(num_proposals=0)
-> 1061 tree, turning, _, _, _ = while_loop(
   1062     _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
   1063 )
   1064 # update depth and turning condition
   1065 return TreeInfo(
   1066     tree.z_left,
   1067     tree.r_left,
   (...)
   1082     tree.num_proposals,
   1083 )

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
    129     return val
    130 else:
--> 131     return lax.while_loop(cond_fun, body_fun, init_val)

    [... skipping hidden 13 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1006, in _iterative_build_subtree.<locals>._body_fn(state)
   1004 # If we are going to the right, start from the right leaf of the current tree.
   1005 z, r, z_grad = _get_leaf(current_tree, going_right)
-> 1006 new_leaf = _build_basetree(
   1007     vv_update,
   1008     kinetic_fn,
   1009     z,
   1010     r,
   1011     z_grad,
   1012     inverse_mass_matrix,
   1013     step_size,
   1014     going_right,
   1015     energy_current,
   1016     max_delta_energy,
   1017 )
   1018 new_tree = cond(
   1019     current_tree.num_proposals == 0,
   1020     new_leaf,
   (...)
   1029     lambda x: _combine_tree(*x, False),
   1030 )
   1032 leaf_idx = current_tree.num_proposals

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:858, in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
    845 def _build_basetree(
    846     vv_update,
    847     kinetic_fn,
   (...)
    855     max_delta_energy,
    856 ):
    857     step_size = jnp.where(going_right, step_size, -step_size)
--> 858     z_new, r_new, potential_energy_new, z_new_grad = vv_update(
    859         step_size, inverse_mass_matrix, (z, r, energy_current, z_grad)
    860     )
    862     energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    863     delta_energy = energy_new - energy_current

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:297, in velocity_verlet.<locals>.update_fn(step_size, inverse_mass_matrix, state)
    295 r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r)
    296 z = tree_map(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
--> 297 potential_energy, z_grad = _value_and_grad(
    298     potential_fn, z, forward_mode_differentiation
    299 )
    300 r = tree_map(
    301     lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad
    302 )  # r(n+1)
    303 return IntegratorState(z, r, potential_energy, z_grad)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 8 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
    244 substituted_model = substitute(
    245     model, substitute_fn=partial(_unconstrain_reparam, params)
    246 )
    247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
    249     substituted_model, model_args, model_kwargs, {}
    250 )
    251 return -log_joint

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
     50 """
     51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
     52 latent values ``params``.
   (...)
     59 :return: log of joint density and a corresponding model trace
     60 """
     61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     63 log_joint = jnp.zeros(())
     64 for site in model_trace.values():

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (4 times)]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Input In [2], in posterior_model(num_samples, post_samples)
     20 probs = np.ones(num_choices)/num_choices 
     21 #categorical distribution where probability of drawing a random sample is the same for all prior model samples
---> 22 beta_ix = numpyro.sample('beta_ix',dist.Categorical(probs=probs)).astype(np.int32)
     24 intercept = post_samples['intercept'][beta_ix]
     25 beta = post_samples['beta'][beta_ix]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    204 initial_msg = {
    205     "type": "sample",
    206     "name": name,
   (...)
    215     "infer": {} if infer is None else infer,
    216 }
    218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
    220 return msg["value"]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:720, in seed.process_message(self, msg)
    717 if msg["value"] is not None:
    718     # no need to create a new key when value is available
    719     return
--> 720 self.rng_key, rng_key_sample = random.split(self.rng_key)
    721 msg["kwargs"]["rng_key"] = rng_key_sample

File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/random.py:194, in split(key, num)
    183 """Splits a PRNG key into `num` new keys by adding a leading axis.
    184 
    185 Args:
   (...)
    191   An array-like object of `num` new PRNG keys.
    192 """
    193 key, wrapped = _check_prng_key(key)
--> 194 return _return_prng_keys(wrapped, _split(key, num))

File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/random.py:180, in _split(key, num)
    177 def _split(key: KeyArray, num: int = 2) -> KeyArray:
    178   # Alternative to split() to use within random samplers.
    179   # TODO(frostig): remove and use split() once we always enable_custom_prng
--> 180   return key._split(num)

File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/prng.py:203, in PRNGKeyArray._split(self, num)
    202 def _split(self, num: int) -> 'PRNGKeyArray':
--> 203   return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))

File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/prng.py:474, in threefry_split(key, num)
    473 def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
--> 474   return _threefry_split(key, int(num))

    [... skipping hidden 5 frame]

File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/interpreters/partial_eval.py:1351, in DynamicJaxprTracer._assert_live(self)
   1349 def _assert_live(self) -> None:
   1350   if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1351     raise core.escaped_tracer_error(self, None)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2,) and dtype uint32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was body_fn at /home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:314 traced for while_loop.
------------------------------
The leaked intermediate value was created on line /home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:716 (process_message). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:102 (__call__)
/local_scratch/pbs.5250180.pbs02/ipykernel_1291431/3692405016.py:25 (posterior_model)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:186 (sample)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:46 (apply_stack)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:716 (process_message)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Strangely enough, this problem doesn’t occur when I run only one chain in the second MCMC. I may be wrong, but I think there is a bug involving the block handler and the passing of the PRNGKey between chains.

I guess JAX does not allow us to store a traced rng_key (under pmap) in the seeded instance. You can simply pass the key to the model as an argument instead:

def model(..., rng_beta_ix):
    beta_ix = dist.(...).sample(rng_beta_ix)
    ...

Thank you, that worked! I’d also love to hear any criticisms you have of my method for performing posterior inference. I know that running MCMC twice works, but it feels a little hacky.