Hello,
I am currently working on a project where I am facing an unusual error regarding the block handler and seeding in Numpyro. Below is a simplified version of my code that recreates the error and the error message.
It may be easier to read through my code, however here is an overview of what I am trying to achieve:
I am trying to perform inference by using this two step MCMC approach. I first try to generate good samples from the beta and intercept distributions by running MCMC on the prior_model and conditioning the prior_model using the covariate and pred data.
Next I run MCMC again but this time on the posterior model to infer the covariate. I do that by conditioning on only the “pred” data and randomly sampling from the beta and intercept prior_model samples.
The issue is that randomly sampling from those prior samples is a bit harder than I realized. I want to make sure that the indices are random during MCMC so I need to make sure that beta_ix changes even after jax compiles it. My solution so far is to create another variable named beta_ix as a categorical distribution that will index into those samples randomly. I then use the block handler on that variable since I don’t want to perform inference on it. The issue with that is there is a bunch of weird effects with the seeding and I’m pretty sure that the trace error has to do with the key because I think the dimensions and dtype of the lost tracer are the same as a PRNGKey.
Thank you in advance! Please let me know if I can clarify anything. Also any ideas are welcome I’d love to know if there is a much easier way of doing this.
import funsor
import jax
from jax import random
import jax.numpy as np
import matplotlib.pyplot as plt
import matplotlib
import numpy as onp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import trace, seed, condition, block
from numpyro.infer import MCMC, NUTS, Predictive
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)
# Prior Model for fitting the coefficients
def prior_model(num_samples):
#defining coefficients in model
intercept = numpyro.sample('intercept', dist.Normal(0,1))
beta = numpyro.sample('beta', dist.Normal(0,1))
with numpyro.plate('data', num_samples):
covariate = numpyro.sample('covariate', dist.Normal(0,3))
logpi = intercept + covariate*beta
pred = numpyro.sample('pred', dist.Bernoulli(logits=logpi))
return intercept, covariate, beta, pred
#posterior model for performing inference on the "covariate" random variable
def posterior_model(num_samples,post_samples):
#defining beta_ix
#purpose is to index into the posterior samples randomly during MCMC
num_choices = len(post_samples['beta'])
probs = np.ones(num_choices)/num_choices
#categorical distribution where probability of drawing a random sample is the same for all prior model samples
beta_ix = numpyro.sample('beta_ix',dist.Categorical(probs=probs)).astype(np.int32)
intercept = post_samples['intercept'][beta_ix]
beta = post_samples['beta'][beta_ix]
with numpyro.plate('data', num_samples):
covariate = numpyro.sample('covariate', dist.Normal(0,1))
logpi = intercept + covariate*beta
pred = numpyro.sample('pred', dist.Bernoulli(logits=logpi))
return intercept, covariate, beta, pred, beta_ix
#fake data
num_samples = 10
C = 1.3 + 0.5 * onp.random.randn(num_samples)
P = onp.random.randint(0,2, (num_samples,))
#observing covariate and pred and running MCMC for fitted beta and intercept
model_cond = numpyro.handlers.condition(prior_model, data={'covariate': C, 'pred': P})
nuts_kernel = NUTS(model_cond)
mcmc = MCMC(nuts_kernel, num_samples=5, num_warmup=100, num_chains=4)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, num_samples)
prior_samples=mcmc.get_samples()
prior_samples.keys()
rng_key = random.PRNGKey(0)
data = {'pred':P}
#seed posterior model
seeded = seed(posterior_model, rng_key)
#blocking beta_ix so that the model does not try to perform inference on beta_ix
pp_model_blocked = block(seeded, hide=['beta_ix'])
#condition posterior model
model_cond_post = numpyro.handlers.condition(
pp_model_blocked,
data=data
)
#run MCMC
nuts_kernel_post = NUTS(model_cond_post)
mcmc_post = MCMC(nuts_kernel_post, num_samples=15, num_warmup=100, num_chains=4)
mcmc_post.run(rng_key, num_samples, prior_samples)
---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
Input In [9], in <cell line: 19>()
17 nuts_kernel_post = NUTS(model_cond_post)
18 mcmc_post = MCMC(nuts_kernel_post, num_samples=15, num_warmup=100, num_chains=4)
---> 19 mcmc_post.run(rng_key, num_samples, prior_samples)
20 mcmc_post.run(jax.random.PRNGKey(1), num_samples, prior_samples)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:404, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
398 collection_size = self._collection_params["collection_size"]
399 collection_size = (
400 collection_size
401 if collection_size is None
402 else collection_size // self.thinning
403 )
--> 404 collect_vals = fori_collect(
405 lower_idx,
406 upper_idx,
407 sample_fn,
408 init_val,
409 transform=_collect_fn(collect_fields),
410 progbar=self.progress_bar,
411 return_last_val=True,
412 thinning=self.thinning,
413 collection_size=collection_size,
414 progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
415 diagnostics_fn=diagnostics,
416 num_chains=self.num_chains if self.chain_method == "parallel" else 1,
417 )
418 states, last_val = collect_vals
419 # Get first argument of type `HMCState`
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:344, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
342 progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
343 _body_fn_pbar = progress_bar_fori_loop(_body_fn)
--> 344 last_val, collection, _, _ = fori_loop(
345 0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
346 )
347 else:
348 diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:141, in fori_loop(lower, upper, body_fun, init_val)
139 return val
140 else:
--> 141 return lax.fori_loop(lower, upper, body_fun, init_val)
[... skipping hidden 16 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:248, in progress_bar_factory.<locals>.progress_bar_fori_loop.<locals>.wrapper_progress_bar(i, vals)
247 def wrapper_progress_bar(i, vals):
--> 248 result = func(i, vals)
249 _update_progress_bar(i + 1)
250 return result
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:323, in fori_collect.<locals>._body_fn(i, vals)
320 @cached_by(fori_collect, body_fun, transform)
321 def _body_fn(i, vals):
322 val, collection, start_idx, thinning = vals
--> 323 val = body_fun(val)
324 idx = (i - start_idx) // thinning
325 collection = cond(
326 idx >= 0,
327 collection,
(...)
330 identity,
331 )
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/mcmc.py:172, in _sample_fn_nojit_args(state, sampler, args, kwargs)
170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
171 # state is a tuple of size 1 - containing HMCState
--> 172 return (sampler.sample(state[0], args, kwargs),)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:771, in HMC.sample(self, state, model_args, model_kwargs)
761 def sample(self, state, model_args, model_kwargs):
762 """
763 Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting
764 :data:`~numpyro.infer.hmc.HMCState`.
(...)
769 :return: Next `state` after running HMC.
770 """
--> 771 return self._sample_fn(state, model_args, model_kwargs)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:467, in hmc.<locals>.sample_kernel(hmc_state, model_args, model_kwargs)
463 else:
464 hmc_length_args = (
465 jnp.where(hmc_state.i < wa_steps, max_treedepth[0], max_treedepth[1]),
466 )
--> 467 vv_state, energy, num_steps, accept_prob, diverging = _next(
468 hmc_state.adapt_state.step_size,
469 hmc_state.adapt_state.inverse_mass_matrix,
470 vv_state,
471 model_args,
472 model_kwargs,
473 rng_key_transition,
474 *hmc_length_args,
475 )
476 # not update adapt_state after warmup phase
477 adapt_state = cond(
478 hmc_state.i < wa_steps,
479 (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
(...)
482 identity,
483 )
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc.py:407, in hmc.<locals>._nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, max_treedepth_current)
404 pe_fn = potential_fn_gen(*model_args, **model_kwargs)
405 _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 407 binary_tree = build_tree(
408 vv_update,
409 kinetic_fn,
410 vv_state,
411 inverse_mass_matrix,
412 step_size,
413 rng_key,
414 max_delta_energy=max_delta_energy,
415 max_tree_depth=(max_treedepth_current, max(max_treedepth)),
416 )
417 accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
418 num_steps = binary_tree.num_proposals
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1177, in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
1174 return tree, key
1176 state = (tree, rng_key)
-> 1177 tree, _ = while_loop(_cond_fn, _body_fn, state)
1178 return tree
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
[... skipping hidden 13 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1161, in build_tree.<locals>._body_fn(state)
1159 key, direction_key, doubling_key = random.split(key, 3)
1160 going_right = random.bernoulli(direction_key)
-> 1161 tree = _double_tree(
1162 tree,
1163 verlet_update,
1164 kinetic_fn,
1165 inverse_mass_matrix,
1166 step_size,
1167 going_right,
1168 doubling_key,
1169 energy_current,
1170 max_delta_energy,
1171 r_ckpts,
1172 r_sum_ckpts,
1173 )
1174 return tree, key
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:916, in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
901 def _double_tree(
902 current_tree,
903 vv_update,
(...)
912 r_sum_ckpts,
913 ):
914 key, transition_key = random.split(rng_key)
--> 916 new_tree = _iterative_build_subtree(
917 current_tree,
918 vv_update,
919 kinetic_fn,
920 inverse_mass_matrix,
921 step_size,
922 going_right,
923 key,
924 energy_current,
925 max_delta_energy,
926 r_ckpts,
927 r_sum_ckpts,
928 )
930 return _combine_tree(
931 current_tree, new_tree, inverse_mass_matrix, going_right, transition_key, True
932 )
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1061, in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
1057 return new_tree, turning, r_ckpts, r_sum_ckpts, rng_key
1059 basetree = prototype_tree._replace(num_proposals=0)
-> 1061 tree, turning, _, _, _ = while_loop(
1062 _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
1063 )
1064 # update depth and turning condition
1065 return TreeInfo(
1066 tree.z_left,
1067 tree.r_left,
(...)
1082 tree.num_proposals,
1083 )
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
[... skipping hidden 13 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1006, in _iterative_build_subtree.<locals>._body_fn(state)
1004 # If we are going to the right, start from the right leaf of the current tree.
1005 z, r, z_grad = _get_leaf(current_tree, going_right)
-> 1006 new_leaf = _build_basetree(
1007 vv_update,
1008 kinetic_fn,
1009 z,
1010 r,
1011 z_grad,
1012 inverse_mass_matrix,
1013 step_size,
1014 going_right,
1015 energy_current,
1016 max_delta_energy,
1017 )
1018 new_tree = cond(
1019 current_tree.num_proposals == 0,
1020 new_leaf,
(...)
1029 lambda x: _combine_tree(*x, False),
1030 )
1032 leaf_idx = current_tree.num_proposals
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:858, in _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right, energy_current, max_delta_energy)
845 def _build_basetree(
846 vv_update,
847 kinetic_fn,
(...)
855 max_delta_energy,
856 ):
857 step_size = jnp.where(going_right, step_size, -step_size)
--> 858 z_new, r_new, potential_energy_new, z_new_grad = vv_update(
859 step_size, inverse_mass_matrix, (z, r, energy_current, z_grad)
860 )
862 energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
863 delta_energy = energy_new - energy_current
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:297, in velocity_verlet.<locals>.update_fn(step_size, inverse_mass_matrix, state)
295 r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r)
296 z = tree_map(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1)
--> 297 potential_energy, z_grad = _value_and_grad(
298 potential_fn, z, forward_mode_differentiation
299 )
300 r = tree_map(
301 lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad
302 ) # r(n+1)
303 return IntegratorState(z, r, potential_energy, z_grad)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
--> 246 return value_and_grad(f)(x)
[... skipping hidden 8 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(_unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:62, in log_density(model, model_args, model_kwargs, params)
50 """
51 (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
52 latent values ``params``.
(...)
59 :return: log of joint density and a corresponding model trace
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (4 times)]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Input In [2], in posterior_model(num_samples, post_samples)
20 probs = np.ones(num_choices)/num_choices
21 #categorical distribution where probability of drawing a random sample is the same for all prior model samples
---> 22 beta_ix = numpyro.sample('beta_ix',dist.Categorical(probs=probs)).astype(np.int32)
24 intercept = post_samples['intercept'][beta_ix]
25 beta = post_samples['beta'][beta_ix]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:219, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
204 initial_msg = {
205 "type": "sample",
206 "name": name,
(...)
215 "infer": {} if infer is None else infer,
216 }
218 # ...and use apply_stack to send it to the Messengers
--> 219 msg = apply_stack(initial_msg)
220 return msg["value"]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File ~/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:720, in seed.process_message(self, msg)
717 if msg["value"] is not None:
718 # no need to create a new key when value is available
719 return
--> 720 self.rng_key, rng_key_sample = random.split(self.rng_key)
721 msg["kwargs"]["rng_key"] = rng_key_sample
File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/random.py:194, in split(key, num)
183 """Splits a PRNG key into `num` new keys by adding a leading axis.
184
185 Args:
(...)
191 An array-like object of `num` new PRNG keys.
192 """
193 key, wrapped = _check_prng_key(key)
--> 194 return _return_prng_keys(wrapped, _split(key, num))
File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/random.py:180, in _split(key, num)
177 def _split(key: KeyArray, num: int = 2) -> KeyArray:
178 # Alternative to split() to use within random samplers.
179 # TODO(frostig): remove and use split() once we always enable_custom_prng
--> 180 return key._split(num)
File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/prng.py:203, in PRNGKeyArray._split(self, num)
202 def _split(self, num: int) -> 'PRNGKeyArray':
--> 203 return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/_src/prng.py:474, in threefry_split(key, num)
473 def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
--> 474 return _threefry_split(key, int(num))
[... skipping hidden 5 frame]
File ~/.conda/envs/pmi/lib/python3.10/site-packages/jax/interpreters/partial_eval.py:1351, in DynamicJaxprTracer._assert_live(self)
1349 def _assert_live(self) -> None:
1350 if not self._trace.main.jaxpr_stack: # type: ignore
-> 1351 raise core.escaped_tracer_error(self, None)
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (2,) and dtype uint32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was body_fn at /home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/infer/util.py:314 traced for while_loop.
------------------------------
The leaked intermediate value was created on line /home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:716 (process_message).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:102 (__call__)
/local_scratch/pbs.5250180.pbs02/ipykernel_1291431/3692405016.py:25 (posterior_model)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:186 (sample)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/primitives.py:46 (apply_stack)
/home/nnisbet/.conda/envs/pmi/lib/python3.10/site-packages/numpyro/handlers.py:716 (process_message)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError