Weird Bernoulli distribution shape

Hi, I’m using NumPyro version 0.9.0 to reimplement Tim Salimans’s solution to the Kaggle challenge “Don’t Overfit”, for learning purposes.

I’m having trouble with the implementation: jax.numpy doesn’t let me multiply the to_include variable by X because of a shape mismatch.

Here is my implementation:

from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def tim_dont_overfit_model(X, t):
    """
    NumPyro model: thresholded linear score over a sampled subset of variables.

    X: a numpy array of shape [n, 200] containing the explanatory variables.
    t: a numpy array of shape [n,] containing the dependent variable.

    Sample sites declared here: 'coefficient' and 'to_include' (both inside
    the 'vars' plate, whose size is the number of columns of X) and the
    observed site 't' (inside the 'obs' plate, whose size is len(t)).
    """
    _, n_variables = X.shape

    with numpyro.plate('vars', n_variables):
        # Every variable gets its own coefficient with a Uniform(0, 1) prior.
        coeffs = numpyro.sample('coefficient', dist.Uniform(0, 1))

        # Every variable also gets an inclusion indicator drawn from
        # Bernoulli(0.5), i.e. roughly half of the variables are expected
        # to take part in creating the target variable.
        to_include = numpyro.sample('to_include', dist.Bernoulli(0.5))

    # Debug output: shapes of the sampled sites and of X.
    print(n_variables, to_include.shape, X.shape, coeffs.shape)
    # Y = (to_include * X) . coeffs
    # The elementwise product needs to_include to broadcast against X.
    Y = jnp.dot(to_include * X, coeffs)

    # Z is 1.0 where Y lies below its mean and 0.0 everywhere else.
    Z = jnp.where(Y - jnp.mean(Y) < 0, 1., 0.)

    with numpyro.plate('obs', len(t)):
        # Z is handed to dist.Bernoulli as its first argument; t is observed.
        numpyro.sample('t', dist.Bernoulli(Z), obs=t)

# Source of randomness for the sampler: a JAX PRNG key with a fixed seed.
rng_key = random.PRNGKey(0)

# Run NUTS.
# NOTE(review): the model declares a discrete latent site ('to_include');
# confirm that a plain NUTS kernel is appropriate for it.
kernel = NUTS(tim_dont_overfit_model)
num_samples = 50000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
# Synthetic inputs: np.random.rand draws uniform floats in [0, 1) from the
# unseeded NumPy RNG, so X and t differ between runs.
# NOTE(review): t is therefore not 0/1 valued -- confirm this is intended
# for a site observed under dist.Bernoulli.
mcmc.run(
    rng_key,
    X=np.random.rand(250, 200),
    t=np.random.rand(250),
)
mcmc.print_summary()

And here is the output with errors:

200 (200,) (250, 200) (200,)
200 (2, 1) (250, 200) (200,)

/tmp/ipykernel_23613/978136461.py:8: FutureWarning: Some algorithms will automatically enumerate the discrete latent site to_include of your model. In the future, enumerated sites need to be marked with `infer={'enumerate': 'parallel'}`.
  mcmc.run(

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
Input In [12], in <module>
      7 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
----> 8 mcmc.run(
      9     rng_key,
     10     X=np.random.rand(250, 200),
     11     t=np.random.rand(250),
     12 )
     13 mcmc.print_summary()

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    598 if self.num_chains == 1:
--> 599     states_flat, last_state = partial_map_fn(map_args)
    600     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:387, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    386 if init_state is None:
--> 387     init_state = self.sampler.init(
    388         rng_key,
    389         self.num_warmup,
    390         init_params,
    391         model_args=args,
    392         model_kwargs=kwargs,
    393     )
    394 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:696, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    693     rng_key, rng_key_init_model = jnp.swapaxes(
    694         vmap(random.split)(rng_key), 0, 1
    695     )
--> 696 init_params = self._init_state(
    697     rng_key_init_model, model_args, model_kwargs, init_params
    698 )
    699 if self._potential_fn and init_params is None:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:642, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    641 if self._model is not None:
--> 642     init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    643         rng_key,
    644         self._model,
    645         dynamic_args=True,
    646         init_strategy=self._init_strategy,
    647         model_args=model_args,
    648         model_kwargs=model_kwargs,
    649         forward_mode_differentiation=self._forward_mode_differentiation,
    650     )
    651     if self._init_fn is None:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
    655     rng_key,
    656     substitute(
    657         model,
    658         data={
    659             k: site["value"]
    660             for k, site in model_trace.items()
    661             if site["type"] in ["plate"]
    662         },
    663     ),
    664     init_strategy=init_strategy,
    665     enum=has_enumerate_support,
    666     model_args=model_args,
    667     model_kwargs=model_kwargs,
    668     prototype_params=prototype_params,
    669     forward_mode_differentiation=forward_mode_differentiation,
    670     validate_grad=validate_grad,
    671 )
    673 if not_jax_tracer(is_valid):

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    394 if rng_key.ndim == 1:
--> 395     (init_params, pe, z_grad), is_valid = _find_valid_params(
    396         rng_key, exit_early=True
    397     )
    398 else:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    378 if exit_early and not_jax_tracer(rng_key):
    379     # Early return if valid params found. This is only helpful for single chain,
    380     # where we can avoid compiling body_fn in while_loop.
--> 381     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    382     if not_jax_tracer(is_valid):

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
    365 else:
--> 366     pe, z_grad = value_and_grad(potential_fn)(params)
    367 z_grad_flat = ravel_pytree(z_grad)[0]

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
    247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
    249     substituted_model, model_args, model_kwargs, {}
    250 )
    251 return -log_joint

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
    250 """
    251 Similar to :func:`numpyro.infer.util.log_density` but works for models
    252 with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    268 :return: log of joint density and a corresponding model trace
    269 """
--> 270 result, model_trace, _ = _enum_log_density(
    271     model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272 )
    273 return result.data, model_trace

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    158 with plate_to_enum_plate():
--> 159     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    160 log_factors = []

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    164 """
    165 Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169 :return: `OrderedDict` containing the execution trace.
    170 """
--> 171 self(*args, **kwargs)
    172 return self.trace

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     86 with self:
---> 87     return self.fn(*args, **kwargs)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     86 with self:
---> 87     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 87 (4 times)]

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     86 with self:
---> 87     return self.fn(*args, **kwargs)

Input In [11], in tim_dont_overfit_model(X, t)
     17 print(n_variables, to_include.shape, X.shape, coeffs.shape)
---> 18 Y = jnp.dot(to_include * X, coeffs)
     20 # Z = Y - mean(Y)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5333, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   5332   return NotImplemented
-> 5333 return binary_op(self, other)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:396, in _maybe_bool_binop.<locals>.fn(x1, x2)
    395 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 396 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

FilteredStackTrace: TypeError: mul got incompatible shapes for broadcasting: (2, 1), (250, 200).

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Input In [12], in <module>
      6 num_samples = 50000
      7 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
----> 8 mcmc.run(
      9     rng_key,
     10     X=np.random.rand(250, 200),
     11     t=np.random.rand(250),
     12 )
     13 mcmc.print_summary()

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597 map_args = (rng_key, init_state, init_params)
    598 if self.num_chains == 1:
--> 599     states_flat, last_state = partial_map_fn(map_args)
    600     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    601 else:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:387, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    385 rng_key, init_state, init_params = init
    386 if init_state is None:
--> 387     init_state = self.sampler.init(
    388         rng_key,
    389         self.num_warmup,
    390         init_params,
    391         model_args=args,
    392         model_kwargs=kwargs,
    393     )
    394 sample_fn, postprocess_fn = self._get_cached_fns()
    395 diagnostics = (
    396     lambda x: self.sampler.get_diagnostics_str(x[0])
    397     if rng_key.ndim == 1
    398     else ""
    399 )  # noqa: E731

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:696, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    691 # vectorized
    692 else:
    693     rng_key, rng_key_init_model = jnp.swapaxes(
    694         vmap(random.split)(rng_key), 0, 1
    695     )
--> 696 init_params = self._init_state(
    697     rng_key_init_model, model_args, model_kwargs, init_params
    698 )
    699 if self._potential_fn and init_params is None:
    700     raise ValueError(
    701         "Valid value of `init_params` must be provided with" " `potential_fn`."
    702     )

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:642, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    640 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    641     if self._model is not None:
--> 642         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    643             rng_key,
    644             self._model,
    645             dynamic_args=True,
    646             init_strategy=self._init_strategy,
    647             model_args=model_args,
    648             model_kwargs=model_kwargs,
    649             forward_mode_differentiation=self._forward_mode_differentiation,
    650         )
    651         if self._init_fn is None:
    652             self._init_fn, self._sample_fn = hmc(
    653                 potential_fn_gen=potential_fn,
    654                 kinetic_fn=self._kinetic_fn,
    655                 algo=self._algo,
    656             )

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    652     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
    655     rng_key,
    656     substitute(
    657         model,
    658         data={
    659             k: site["value"]
    660             for k, site in model_trace.items()
    661             if site["type"] in ["plate"]
    662         },
    663     ),
    664     init_strategy=init_strategy,
    665     enum=has_enumerate_support,
    666     model_args=model_args,
    667     model_kwargs=model_kwargs,
    668     prototype_params=prototype_params,
    669     forward_mode_differentiation=forward_mode_differentiation,
    670     validate_grad=validate_grad,
    671 )
    673 if not_jax_tracer(is_valid):
    674     if device_get(~jnp.all(is_valid)):

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    393 # Handle possible vectorization
    394 if rng_key.ndim == 1:
--> 395     (init_params, pe, z_grad), is_valid = _find_valid_params(
    396         rng_key, exit_early=True
    397     )
    398 else:
    399     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
    378 if exit_early and not_jax_tracer(rng_key):
    379     # Early return if valid params found. This is only helpful for single chain,
    380     # where we can avoid compiling body_fn in while_loop.
--> 381     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    382     if not_jax_tracer(is_valid):
    383         if device_get(is_valid):

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
    364     z_grad = jacfwd(potential_fn)(params)
    365 else:
--> 366     pe, z_grad = value_and_grad(potential_fn)(params)
    367 z_grad_flat = ravel_pytree(z_grad)[0]
    368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/traceback_util.py:139, in api_boundary.<locals>.reraise_with_filtered_traceback(*args, **kwargs)
    136 @util.wraps(fun)
    137 def reraise_with_filtered_traceback(*args, **kwargs):
    138   try:
--> 139     return fun(*args, **kwargs)
    140   except Exception as e:
    141     if not is_under_reraiser(e):

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/api.py:815, in value_and_grad.<locals>.value_and_grad_f(*args, **kwargs)
    813 tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
    814 if not has_aux:
--> 815   ans, vjp_py = _vjp(f_partial, *dyn_args)
    816 else:
    817   ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/api.py:1888, in _vjp(fun, has_aux, *primals)
   1886 if not has_aux:
   1887   flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1888   out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1889   out_tree = out_tree()
   1890 else:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/ad.py:114, in vjp(traceable, primals, has_aux)
    112 def vjp(traceable, primals, has_aux=False):
    113   if not has_aux:
--> 114     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    115   else:
    116     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/ad.py:101, in linearize(traceable, *primals, **kwargs)
     99 _, in_tree = tree_flatten(((primals, primals), {}))
    100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
    103 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:510, in trace_to_jaxpr(fun, pvals, instantiate)
    508 with core.new_main(JaxprTrace) as main:
    509   fun = trace_to_subjaxpr(fun, main, instantiate)
--> 510   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    511   assert not env
    512   del main, fun, env

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
    163 gen = gen_static_args = out_store = None
    165 try:
--> 166   ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:
    168   # Some transformations yield from inside context managers, so we have to
    169   # interrupt them before reraising the exception. Otherwise they will only
    170   # get garbage-collected at some later time, running their cleanup tasks only
    171   # after this exception is handled, which can corrupt the global state.
    172   while stack:

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
    244 substituted_model = substitute(
    245     model, substitute_fn=partial(_unconstrain_reparam, params)
    246 )
    247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
    249     substituted_model, model_args, model_kwargs, {}
    250 )
    251 return -log_joint

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
    249 def log_density(model, model_args, model_kwargs, params):
    250     """
    251     Similar to :func:`numpyro.infer.util.log_density` but works for models
    252     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )
    273     return result.data, model_trace

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    157 model = substitute(model, data=params)
    158 with plate_to_enum_plate():
--> 159     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    160 log_factors = []
    161 time_to_factors = defaultdict(list)  # log prob factors

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     85     return self
     86 with self:
---> 87     return self.fn(*args, **kwargs)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     85     return self
     86 with self:
---> 87     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 87 (4 times)]

File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
     85     return self
     86 with self:
---> 87     return self.fn(*args, **kwargs)

Input In [11], in tim_dont_overfit_model(X, t)
     16 # Y = coeffs * X
     17 print(n_variables, to_include.shape, X.shape, coeffs.shape)
---> 18 Y = jnp.dot(to_include * X, coeffs)
     20 # Z = Y - mean(Y)
     21 Z = jnp.where(Y - jnp.mean(Y) < 0, 1., 0.)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5333, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   5331 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
   5332   return NotImplemented
-> 5333 return binary_op(self, other)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:396, in _maybe_bool_binop.<locals>.fn(x1, x2)
    394 def fn(x1, x2):
    395   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 396   return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:352, in mul(x, y)
    350 def mul(x: Array, y: Array) -> Array:
    351   r"""Elementwise multiplication: :math:`x \times y`."""
--> 352   return mul_p.bind(x, y)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/core.py:259, in Primitive.bind(self, *args, **params)
    257 top_trace = find_top_trace(args)
    258 tracers = map(top_trace.full_raise, args)
--> 259 out = top_trace.process_primitive(self, tracers, params)
    260 return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/core.py:597, in EvalTrace.process_primitive(self, primitive, tracers, params)
    596 def process_primitive(self, primitive, tracers, params):
--> 597   return primitive.impl(*tracers, **params)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/xla.py:230, in apply_primitive(prim, *args, **params)
    228 def apply_primitive(prim, *args, **params):
    229   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 230   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    231   return compiled_fun(*args)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/util.py:197, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    195   return f(*args, **kwargs)
    196 else:
--> 197   return cached(bool(config.x64_enabled), *args, **kwargs)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/util.py:190, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    188 @functools.lru_cache(max_size)
    189 def cached(_, *args, **kwargs):
--> 190   return f(*args, **kwargs)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/xla.py:255, in xla_primitive_callable(prim, *arg_specs, **params)
    252     return prim.bind(*args, **params)
    253   return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
    254                        *arg_specs)
--> 255 aval_out = prim.abstract_eval(*avals, **params)
    256 if not prim.multiple_results:
    257   handle_result = aval_to_result_handler(device, aval_out)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:2010, in standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, named_shape_rule, *avals, **kwargs)
   2007   return ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
   2008                        weak_type=weak_type)
   2009 elif least_specialized is ShapedArray:
-> 2010   return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
   2011                      weak_type=weak_type,
   2012                      named_shape=named_shape_rule(*avals, **kwargs))
   2013 elif least_specialized is UnshapedArray:
   2014   return UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)

File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:2106, in _broadcasting_shape_rule(name, *avals)
   2104 if result_shape is None:
   2105   msg = '{} got incompatible shapes for broadcasting: {}.'
-> 2106   raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
   2107 return result_shape

TypeError: mul got incompatible shapes for broadcasting: (2, 1), (250, 200).

This is my first time with NumPyro and probabilistic programming. Sorry for the long read, and thank you for any help!

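I think you need to use DiscreteHMCGibbs(NUTS(...)) for your model, e.g. kernel = DiscreteHMCGibbs(NUTS(tim_dont_overfit_model)), with DiscreteHMCGibbs imported from numpyro.infer. The NUTS algorithm does not work for models with discrete latent variables (unless you want to use enumeration, as in the warning message - but enumeration won’t work for your model). The (2, 1) shape in your second printed line comes from that automatic enumeration: to_include is replaced by the two values in the Bernoulli’s support, which can no longer be broadcast against X.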

It works, thank you!