I’m hoping to port over Inference Gym distributions without re-implementing them to avoid any possible differences.
I’m getting Tracer Conversion errors because there are numpy arrays that exist within the Inference Gym distributions
Is there a way to avoid this to port these over properly without having to reimplement them?
import jax
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
from inference_gym import using_jax as gym
class Banana(dist.Distribution):
arg_constraints = {"ndims": dist.constraints.positive_integer, "curvature": dist.constraints.real}
support = dist.constraints.real_vector
pytree_data_fields = ("ndims", "curvature")
def __init__(self, ndims, curvature):
self.ndims = ndims
self.curvature = curvature
self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature)
super().__init__(event_shape=(ndims,))
def sample(self, key, sample_shape=()):
return self.gym_dist.sample(seed=key, sample_shape=sample_shape)
def log_prob(self, value):
return self.gym_dist._unnormalized_log_prob(value)
samples = Banana(ndims=3, curvature=0.03).sample(jax.random.PRNGKey(0), sample_shape=(100,))
def model(X):
curvature = numpyro.sample("curvature", dist.Beta(1,30))
return numpyro.sample("obs", Banana(ndims=3, curvature=curvature), obs=X)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X=samples)
Here’s the full traceback
Traceback
File “/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py”, line 37, in
mcmc.run(jax.random.PRNGKey(0), X=samples)
File “/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py”, line 702, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py”, line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py”, line 749, in init
init_params = self._init_state(
^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py”, line 693, in _init_state
) = initialize_model(
^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 713, in initialize_model
(init_params, pe, grad), is_valid = find_valid_initial_params(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 447, in find_valid_initial_params
(init_params, pe, z_grad), is_valid = _find_valid_params(
^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 433, in _find_valid_params
_, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 417, in body_fn
pe, z_grad = value_and_grad(potential_fn)(params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py”, line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/api.py”, line 468, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/api.py”, line 1975, in _vjp
out_primals, vjp = ad.vjp(flat_fun, primals_flat)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py”, line 252, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py”, line 237, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/profiler.py”, line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py”, line 574, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py”, line 192, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py”, line 587, in trace_to_subjaxpr_nounits
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py”, line 616, in _trace_to_subjaxpr_nounits
ans = f(*in_args)
^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/api_util.py”, line 72, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py”, line 78, in jvpfun
out_primals, out_tangents = f(tag, primals, tangents)
^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py”, line 115, in jvp_subtrace
ans = f(*in_tracers)
^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/api_util.py”, line 88, in flatten_fun_nokwargs
ans = f(*py_args)
^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/api_util.py”, line 292, in _argnums_partial
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 299, in potential_energy
log_joint, model_trace = log_density(
^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/infer/util.py”, line 70, in log_density
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/handlers.py”, line 186, in get_trace
self(*args, **kwargs)
File “/.venv/lib/python3.11/site-packages/numpyro/primitives.py”, line 105, in call
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/primitives.py”, line 105, in call
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/primitives.py”, line 105, in call
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
[Previous line repeated 3 more times]
File “/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py”, line 31, in model
return numpyro.sample(“obs”, Banana(ndims=3, curvature=curvature), obs=X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/numpyro/distributions/distribution.py”, line 100, in call
return super().call(*args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py”, line 17, in init
self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py”, line 116, in init
[10.] + [np.sqrt(1. + 2 * curvature2 * 10.**4)] +
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/jax/_src/core.py”, line 692, in array
raise TracerArrayConversionError(self)