Calling `mcmc.run(PRNGKey(2), data)` fails with the following `ImportError` (full traceback):
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[10], line 41
39 kernel = NUTS(model)
40 mcmc = MCMC(kernel, num_warmup=50, num_samples=250)
---> 41 mcmc.run(PRNGKey(2), data)
42 mcmc.print_summary()
43 posterior_samples = mcmc.get_samples()
File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
626 map_args = (rng_key, init_state, init_params)
627 if self.num_chains == 1:
--> 628 states_flat, last_state = partial_map_fn(map_args)
629 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
630 else:
File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
408 # Check if _sample_fn is None, then we need to initialize the sampler.
409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 410 new_init_state = self.sampler.init(
411 rng_key,
412 self.num_warmup,
413 init_params,
414 model_args=args,
415 model_kwargs=kwargs,
416 )
417 init_state = new_init_state if init_state is None else init_state
418 sample_fn, postprocess_fn = self._get_cached_fns()
File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
708 # vectorized
709 else:
710 rng_key, rng_key_init_model = jnp.swapaxes(
711 vmap(random.split)(rng_key), 0, 1
712 )
--> 713 init_params = self._init_state(
714 rng_key_init_model, model_args, model_kwargs, init_params
715 )
716 if self._potential_fn and init_params is None:
717 raise ValueError(
718 "Valid value of `init_params` must be provided with" " `potential_fn`."
719 )
File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
652 (
653 new_init_params,
654 potential_fn,
655 postprocess_fn,
656 model_trace,
--> 657 ) = initialize_model(
658 rng_key,
659 self._model,
660 dynamic_args=True,
661 init_strategy=self._init_strategy,
662 model_args=model_args,
663 model_kwargs=model_kwargs,
664 forward_mode_differentiation=self._forward_mode_differentiation,
665 )
666 if init_params is None:
667 init_params = new_init_params
File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:700, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
698 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
699 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 700 (init_params, pe, grad), is_valid = find_valid_initial_params(
701 rng_key,
702 substitute(
703 model,
704 data={
705 k: site["value"]
706 for k, site in model_trace.items()
707 if site["type"] in ["plate"]
708 },
709 ),
710 init_strategy=init_strategy,
711 enum=has_enumerate_support,
712 model_args=model_args,
713 model_kwargs=model_kwargs,
714 prototype_params=prototype_params,
715 forward_mode_differentiation=forward_mode_differentiation,
716 validate_grad=validate_grad,
717 )
719 if not_jax_tracer(is_valid):
720 if device_get(~jnp.all(is_valid)):
File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:437, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
435 # Handle possible vectorization
436 if rng_key.ndim == 1:
--> 437 (init_params, pe, z_grad), is_valid = _find_valid_params(
438 rng_key, exit_early=True
439 )
440 else:
441 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:423, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
419 init_state = (0, rng_key, (prototype_params, 0.0, prototype_grads), False)
420 if exit_early and not_jax_tracer(rng_key):
421 # Early return if valid params found. This is only helpful for single chain,
422 # where we can avoid compiling body_fn in while_loop.
--> 423 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
424 if not_jax_tracer(is_valid):
425 if device_get(is_valid):
File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:407, in find_valid_initial_params.<locals>.body_fn(state)
405 z_grad = jacfwd(potential_fn)(params)
406 else:
--> 407 pe, z_grad = value_and_grad(potential_fn)(params)
408 z_grad_flat = ravel_pytree(z_grad)[0]
409 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:289, in potential_energy(model, model_args, model_kwargs, params, enum)
285 substituted_model = substitute(
286 model, substitute_fn=partial(_unconstrain_reparam, params)
287 )
288 # no param is needed for log_density computation because we already substitute
--> 289 log_joint, model_trace = log_density_(
290 substituted_model, model_args, model_kwargs, {}
291 )
292 return -log_joint
File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:318, in log_density(model, model_args, model_kwargs, params)
297 def log_density(model, model_args, model_kwargs, params):
298 """
299 Similar to :func:`numpyro.infer.util.log_density` but works for models
300 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
316 :return: log of joint density and a corresponding model trace
317 """
--> 318 result, model_trace, _ = _enum_log_density(
319 model,
320 model_args,
321 model_kwargs,
322 params,
323 funsor.ops.logaddexp,
324 funsor.ops.add,
325 )
326 return result.data, model_trace
File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:201, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
199 model = substitute(model, data=params)
200 with plate_to_enum_plate():
--> 201 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
202 log_factors = []
203 time_to_factors = defaultdict(list) # log prob factors
File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (4 times)]
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[10], line 28, in model(data, K)
25 weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
26 scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
---> 28 with numpyro.plate("component",2):
29 locs = numpyro.sample("locs", dist.Normal(0.0, 10.0 ))
31 with numpyro.plate("data", len(data)):
File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:39, in plate_to_enum_plate.<locals>.<lambda>(cls, *args, **kwargs)
24 """
25 A context manager to replace `numpyro.plate` statement by a funsor-based
26 :class:`~numpyro.contrib.funsor.enum_messenger.plate`.
(...)
36
37 """
38 try:
---> 39 numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
40 yield
41 finally:
File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:491, in plate.__init__(self, name, size, subsample_size, dim)
487 self.dim, indices = OrigPlateMessenger._subsample(
488 self.name, self.size, subsample_size, dim
489 )
490 self.subsample_size = indices.shape[0]
--> 491 self._indices = funsor.Tensor(
492 indices,
493 OrderedDict([(self.name, funsor.Bint[self.subsample_size])]),
494 self.subsample_size,
495 )
496 super(plate, self).__init__(None)
File /usr/local/lib/python3.10/site-packages/funsor/tensor.py:112, in TensorMeta.__call__(cls, data, inputs, dtype)
109 if isinstance(data, np.generic):
110 data = data.__array__()
--> 112 return super(TensorMeta, cls).__call__(data, inputs, dtype)
File /usr/local/lib/python3.10/site-packages/funsor/terms.py:211, in FunsorMeta.__call__(cls, *args, **kwargs)
208 assert not kwargs, kwargs
209 args = tuple(args)
--> 211 return interpret(cls, *args)
File /usr/local/lib/python3.10/site-packages/funsor/interpretations.py:196, in PrioritizedInterpretation.interpret(self, cls, *args)
194 def interpret(self, cls, *args):
195 for s in self._subinterpretations:
--> 196 result = s.interpret(cls, *args)
197 if result is not None:
198 return result
File /usr/local/lib/python3.10/site-packages/funsor/terms.py:137, in reflect(cls, *args, **kwargs)
134 assert len(new_args) == len(cls._ast_fields)
135 _, args = args, new_args
--> 137 cache_key = reflect.make_hash_key(cls, *args)
138 if cache_key in cls._cons_cache:
139 return cls._cons_cache[cache_key]
File /usr/local/lib/python3.10/site-packages/funsor/interpretations.py:58, in Interpretation.make_hash_key(cls, *args)
55 backend = get_backend()
56 if backend == "jax":
57 # JAX DeviceArray has .__hash__ method but raise the unhashable error there.
---> 58 from jax.interpreters.xla import DeviceArray
60 return tuple(
61 id(arg)
62 if isinstance(arg, DeviceArray) or not isinstance(arg, Hashable)
63 else arg
64 for arg in args
65 )
66 if backend == "torch":
67 # Avoid "ImportError: sys.meta_path is None" on shutdown.
ImportError: cannot import name 'DeviceArray' from 'jax.interpreters.xla' (/usr/local/lib/python3.10/site-packages/jax/interpreters/xla.py)