Gaussian Mixture Model (GMM) example from the NumPyro tutorial does not run

–Question/problem
I am trying to run the example titled “Gaussian Mixture Model” from the NumPyro documentation; however, the model and the associated call to MCMC return an error.

The shapes of the sites are as follows:
Trace Shapes:
Param Sites:
Sample Sites:
weights dist | 2
value | 2
scale dist |
value |
component plate 2 |
locs dist 2 |
value 2 |
data plate 5 |
assignment dist 5 |
value 5 |
y dist 5 |
value 5 |

–Example of code to run–
# Five 1-D observations: two near 0-1 and three near 10-12.
data = np.asarray([0.0, 1.0, 10.0, 11.0, 12.0], dtype=float)

def model(data, K=2):
    """Gaussian mixture model with ``K`` components and one shared scale.

    Args:
        data: 1-D array of observations. One ``assignment`` site and one
            observed ``y`` site are created per element (``data`` plate of
            size ``len(data)``).
        K: number of mixture components. It sizes both the Dirichlet prior
            on ``weights`` and the ``component`` plate over ``locs``.
    """
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))

    # Size the plate by K rather than a literal 2. This gives ``locs`` one
    # entry per mixture weight. With a hard-coded 2, any K != 2 leaves
    # ``weights`` and ``locs`` with different lengths.
    with numpyro.plate("component", K):
        locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    with numpyro.plate("data", len(data)):
        # The discrete assignment is marked for parallel enumeration, so
        # inference sums over its values instead of sampling it.
        # This enumeration path goes through funsor.
        assignment = numpyro.sample(
            "assignment",
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )
        numpyro.sample("y", dist.Normal(locs[assignment], scale), obs=data)

# Record a single forward run of the model under a fixed seed, then print
# the shape of every site in that trace.
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(data=data)
print(numpyro.util.format_shapes(trace))

# Sample the posterior with NUTS (50 warmup steps, 250 kept samples).
# Because the "assignment" site is marked for enumeration, the log density
# is computed through funsor. An outdated funsor makes this call fail with
# "ImportError: cannot import name 'DeviceArray' from
# 'jax.interpreters.xla'" when the installed JAX no longer provides that
# name. Upgrading funsor resolves it.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=50, num_samples=250)
mcmc.run(PRNGKey(2), data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

The error message that is returned reads:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[10], line 41
     39 kernel = NUTS(model)
     40 mcmc = MCMC(kernel, num_warmup=50, num_samples=250)
---> 41 mcmc.run(PRNGKey(2), data)
     42 mcmc.print_summary()
     43 posterior_samples = mcmc.get_samples()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    626 map_args = (rng_key, init_state, init_params)
    627 if self.num_chains == 1:
--> 628     states_flat, last_state = partial_map_fn(map_args)
    629     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    630 else:

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    408 # Check if _sample_fn is None, then we need to initialize the sampler.
    409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 410     new_init_state = self.sampler.init(
    411         rng_key,
    412         self.num_warmup,
    413         init_params,
    414         model_args=args,
    415         model_kwargs=kwargs,
    416     )
    417     init_state = new_init_state if init_state is None else init_state
    418 sample_fn, postprocess_fn = self._get_cached_fns()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    708 # vectorized
    709 else:
    710     rng_key, rng_key_init_model = jnp.swapaxes(
    711         vmap(random.split)(rng_key), 0, 1
    712     )
--> 713 init_params = self._init_state(
    714     rng_key_init_model, model_args, model_kwargs, init_params
    715 )
    716 if self._potential_fn and init_params is None:
    717     raise ValueError(
    718         "Valid value of `init_params` must be provided with" " `potential_fn`."
    719     )

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
    652         (
    653             new_init_params,
    654             potential_fn,
    655             postprocess_fn,
    656             model_trace,
--> 657         ) = initialize_model(
    658             rng_key,
    659             self._model,
    660             dynamic_args=True,
    661             init_strategy=self._init_strategy,
    662             model_args=model_args,
    663             model_kwargs=model_kwargs,
    664             forward_mode_differentiation=self._forward_mode_differentiation,
    665         )
    666         if init_params is None:
    667             init_params = new_init_params

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:700, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    698     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    699 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 700 (init_params, pe, grad), is_valid = find_valid_initial_params(
    701     rng_key,
    702     substitute(
    703         model,
    704         data={
    705             k: site["value"]
    706             for k, site in model_trace.items()
    707             if site["type"] in ["plate"]
    708         },
    709     ),
    710     init_strategy=init_strategy,
    711     enum=has_enumerate_support,
    712     model_args=model_args,
    713     model_kwargs=model_kwargs,
    714     prototype_params=prototype_params,
    715     forward_mode_differentiation=forward_mode_differentiation,
    716     validate_grad=validate_grad,
    717 )
    719 if not_jax_tracer(is_valid):
    720     if device_get(~jnp.all(is_valid)):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:437, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    435 # Handle possible vectorization
    436 if rng_key.ndim == 1:
--> 437     (init_params, pe, z_grad), is_valid = _find_valid_params(
    438         rng_key, exit_early=True
    439     )
    440 else:
    441     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:423, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    419 init_state = (0, rng_key, (prototype_params, 0.0, prototype_grads), False)
    420 if exit_early and not_jax_tracer(rng_key):
    421     # Early return if valid params found. This is only helpful for single chain,
    422     # where we can avoid compiling body_fn in while_loop.
--> 423     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    424     if not_jax_tracer(is_valid):
    425         if device_get(is_valid):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:407, in find_valid_initial_params.<locals>.body_fn(state)
    405     z_grad = jacfwd(potential_fn)(params)
    406 else:
--> 407     pe, z_grad = value_and_grad(potential_fn)(params)
    408 z_grad_flat = ravel_pytree(z_grad)[0]
    409 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:289, in potential_energy(model, model_args, model_kwargs, params, enum)
    285 substituted_model = substitute(
    286     model, substitute_fn=partial(_unconstrain_reparam, params)
    287 )
    288 # no param is needed for log_density computation because we already substitute
--> 289 log_joint, model_trace = log_density_(
    290     substituted_model, model_args, model_kwargs, {}
    291 )
    292 return -log_joint

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:318, in log_density(model, model_args, model_kwargs, params)
    297 def log_density(model, model_args, model_kwargs, params):
    298     """
    299     Similar to :func:`numpyro.infer.util.log_density` but works for models
    300     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    316     :return: log of joint density and a corresponding model trace
    317     """
--> 318     result, model_trace, _ = _enum_log_density(
    319         model,
    320         model_args,
    321         model_kwargs,
    322         params,
    323         funsor.ops.logaddexp,
    324         funsor.ops.add,
    325     )
    326     return result.data, model_trace

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:201, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    199 model = substitute(model, data=params)
    200 with plate_to_enum_plate():
--> 201     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    202 log_factors = []
    203 time_to_factors = defaultdict(list)  # log prob factors

File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (4 times)]

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[10], line 28, in model(data, K)
     25 weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
     26 scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
---> 28 with numpyro.plate("component",2):
     29     locs = numpyro.sample("locs", dist.Normal(0.0, 10.0 ))
     31 with numpyro.plate("data", len(data)):

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:39, in plate_to_enum_plate.<locals>.<lambda>(cls, *args, **kwargs)
     24 """
     25 A context manager to replace `numpyro.plate` statement by a funsor-based
     26 :class:`~numpyro.contrib.funsor.enum_messenger.plate`.
   (...)
     36 
     37 """
     38 try:
---> 39     numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
     40     yield
     41 finally:

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:491, in plate.__init__(self, name, size, subsample_size, dim)
    487 self.dim, indices = OrigPlateMessenger._subsample(
    488     self.name, self.size, subsample_size, dim
    489 )
    490 self.subsample_size = indices.shape[0]
--> 491 self._indices = funsor.Tensor(
    492     indices,
    493     OrderedDict([(self.name, funsor.Bint[self.subsample_size])]),
    494     self.subsample_size,
    495 )
    496 super(plate, self).__init__(None)

File /usr/local/lib/python3.10/site-packages/funsor/tensor.py:112, in TensorMeta.__call__(cls, data, inputs, dtype)
    109 if isinstance(data, np.generic):
    110     data = data.__array__()
--> 112 return super(TensorMeta, cls).__call__(data, inputs, dtype)

File /usr/local/lib/python3.10/site-packages/funsor/terms.py:211, in FunsorMeta.__call__(cls, *args, **kwargs)
    208     assert not kwargs, kwargs
    209     args = tuple(args)
--> 211 return interpret(cls, *args)

File /usr/local/lib/python3.10/site-packages/funsor/interpretations.py:196, in PrioritizedInterpretation.interpret(self, cls, *args)
    194 def interpret(self, cls, *args):
    195     for s in self._subinterpretations:
--> 196         result = s.interpret(cls, *args)
    197         if result is not None:
    198             return result

File /usr/local/lib/python3.10/site-packages/funsor/terms.py:137, in reflect(cls, *args, **kwargs)
    134     assert len(new_args) == len(cls._ast_fields)
    135     _, args = args, new_args
--> 137 cache_key = reflect.make_hash_key(cls, *args)
    138 if cache_key in cls._cons_cache:
    139     return cls._cons_cache[cache_key]

File /usr/local/lib/python3.10/site-packages/funsor/interpretations.py:58, in Interpretation.make_hash_key(cls, *args)
     55 backend = get_backend()
     56 if backend == "jax":
     57     # JAX DeviceArray has .__hash__ method but raise the unhashable error there.
---> 58     from jax.interpreters.xla import DeviceArray
     60     return tuple(
     61         id(arg)
     62         if isinstance(arg, DeviceArray) or not isinstance(arg, Hashable)
     63         else arg
     64         for arg in args
     65     )
     66 if backend == "torch":
     67     # Avoid "ImportError: sys.meta_path is None" on shutdown.

ImportError: cannot import name 'DeviceArray' from 'jax.interpreters.xla' (/usr/local/lib/python3.10/site-packages/jax/interpreters/xla.py)

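Could you update funsor (e.g. `pip install -U funsor`)? I think the issue has been fixed upstream. The traceback shows that the installed funsor tries to import `DeviceArray` from `jax.interpreters.xla`, and your JAX version no longer provides that name.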

This worked! Thanks for the help.
Now that this works, I can post my more complicated implementation of a mixture model.
Best,
Tom