Hi, I’m using Numpyro version 0.9.0 to reimplement Tim Salimans’s solution for the challenge Don’t Overfit on Kaggle for learning purpose.
I’m having trouble with the implementation. The jax.numpy doesn’t allow me to multiple the to_include
variable with X
due to shape mismatched.
Here is my implementation:
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def tim_dont_overfit_model(X, t):
"""
X: a numpy array of shape [n, 200] contains explanatory variables.
t: a numpy array of shape [n,] contains dependent variables.
"""
_, n_variables = X.shape
with numpyro.plate('vars', n_variables):
# Each of these variables have a coefficient associated with them.
coeffs = numpyro.sample('coefficient', dist.Uniform(0, 1))
# In those 200 variables, there are roughly half of them
# are used to create the target variable.
to_include = numpyro.sample('to_include', dist.Bernoulli(0.5))
# Y = coeffs * X
print(n_variables, to_include.shape, X.shape, coeffs.shape)
Y = jnp.dot(to_include * X, coeffs)
# Z = Y - mean(Y)
Z = jnp.where(Y - jnp.mean(Y) < 0, 1., 0.)
with numpyro.plate('obs', len(t)):
numpyro.sample('t', dist.Bernoulli(Z), obs=t)
# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
# Run NUTS.
kernel = NUTS(tim_dont_overfit_model)
num_samples = 50000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
rng_key,
X=np.random.rand(250, 200),
t=np.random.rand(250),
)
mcmc.print_summary()
And here is the output with errors:
200 (200,) (250, 200) (200,)
200 (2, 1) (250, 200) (200,)
/tmp/ipykernel_23613/978136461.py:8: FutureWarning: Some algorithms will automatically enumerate the discrete latent site to_include of your model. In the future, enumerated sites need to be marked with `infer={'enumerate': 'parallel'}`.
mcmc.run(
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
Input In [12], in <module>
7 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
----> 8 mcmc.run(
9 rng_key,
10 X=np.random.rand(250, 200),
11 t=np.random.rand(250),
12 )
13 mcmc.print_summary()
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
598 if self.num_chains == 1:
--> 599 states_flat, last_state = partial_map_fn(map_args)
600 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:387, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
386 if init_state is None:
--> 387 init_state = self.sampler.init(
388 rng_key,
389 self.num_warmup,
390 init_params,
391 model_args=args,
392 model_kwargs=kwargs,
393 )
394 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:696, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
693 rng_key, rng_key_init_model = jnp.swapaxes(
694 vmap(random.split)(rng_key), 0, 1
695 )
--> 696 init_params = self._init_state(
697 rng_key_init_model, model_args, model_kwargs, init_params
698 )
699 if self._potential_fn and init_params is None:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:642, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
641 if self._model is not None:
--> 642 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
643 rng_key,
644 self._model,
645 dynamic_args=True,
646 init_strategy=self._init_strategy,
647 model_args=model_args,
648 model_kwargs=model_kwargs,
649 forward_mode_differentiation=self._forward_mode_differentiation,
650 )
651 if self._init_fn is None:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
250 """
251 Similar to :func:`numpyro.infer.util.log_density` but works for models
252 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
273 return result.data, model_trace
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
158 with plate_to_enum_plate():
--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
160 log_factors = []
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 87 (4 times)]
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
Input In [11], in tim_dont_overfit_model(X, t)
17 print(n_variables, to_include.shape, X.shape, coeffs.shape)
---> 18 Y = jnp.dot(to_include * X, coeffs)
20 # Z = Y - mean(Y)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5333, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
5332 return NotImplemented
-> 5333 return binary_op(self, other)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:396, in _maybe_bool_binop.<locals>.fn(x1, x2)
395 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 396 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
FilteredStackTrace: TypeError: mul got incompatible shapes for broadcasting: (2, 1), (250, 200).
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Input In [12], in <module>
6 num_samples = 50000
7 mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
----> 8 mcmc.run(
9 rng_key,
10 X=np.random.rand(250, 200),
11 t=np.random.rand(250),
12 )
13 mcmc.print_summary()
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 map_args = (rng_key, init_state, init_params)
598 if self.num_chains == 1:
--> 599 states_flat, last_state = partial_map_fn(map_args)
600 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
601 else:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/mcmc.py:387, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
385 rng_key, init_state, init_params = init
386 if init_state is None:
--> 387 init_state = self.sampler.init(
388 rng_key,
389 self.num_warmup,
390 init_params,
391 model_args=args,
392 model_kwargs=kwargs,
393 )
394 sample_fn, postprocess_fn = self._get_cached_fns()
395 diagnostics = (
396 lambda x: self.sampler.get_diagnostics_str(x[0])
397 if rng_key.ndim == 1
398 else ""
399 ) # noqa: E731
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:696, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
691 # vectorized
692 else:
693 rng_key, rng_key_init_model = jnp.swapaxes(
694 vmap(random.split)(rng_key), 0, 1
695 )
--> 696 init_params = self._init_state(
697 rng_key_init_model, model_args, model_kwargs, init_params
698 )
699 if self._potential_fn and init_params is None:
700 raise ValueError(
701 "Valid value of `init_params` must be provided with" " `potential_fn`."
702 )
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/hmc.py:642, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
640 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
641 if self._model is not None:
--> 642 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
643 rng_key,
644 self._model,
645 dynamic_args=True,
646 init_strategy=self._init_strategy,
647 model_args=model_args,
648 model_kwargs=model_kwargs,
649 forward_mode_differentiation=self._forward_mode_differentiation,
650 )
651 if self._init_fn is None:
652 self._init_fn, self._sample_fn = hmc(
653 potential_fn_gen=potential_fn,
654 kinetic_fn=self._kinetic_fn,
655 algo=self._algo,
656 )
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:654, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
652 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
653 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 654 (init_params, pe, grad), is_valid = find_valid_initial_params(
655 rng_key,
656 substitute(
657 model,
658 data={
659 k: site["value"]
660 for k, site in model_trace.items()
661 if site["type"] in ["plate"]
662 },
663 ),
664 init_strategy=init_strategy,
665 enum=has_enumerate_support,
666 model_args=model_args,
667 model_kwargs=model_kwargs,
668 prototype_params=prototype_params,
669 forward_mode_differentiation=forward_mode_differentiation,
670 validate_grad=validate_grad,
671 )
673 if not_jax_tracer(is_valid):
674 if device_get(~jnp.all(is_valid)):
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:395, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
393 # Handle possible vectorization
394 if rng_key.ndim == 1:
--> 395 (init_params, pe, z_grad), is_valid = _find_valid_params(
396 rng_key, exit_early=True
397 )
398 else:
399 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:381, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
377 init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
378 if exit_early and not_jax_tracer(rng_key):
379 # Early return if valid params found. This is only helpful for single chain,
380 # where we can avoid compiling body_fn in while_loop.
--> 381 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
382 if not_jax_tracer(is_valid):
383 if device_get(is_valid):
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:366, in find_valid_initial_params.<locals>.body_fn(state)
364 z_grad = jacfwd(potential_fn)(params)
365 else:
--> 366 pe, z_grad = value_and_grad(potential_fn)(params)
367 z_grad_flat = ravel_pytree(z_grad)[0]
368 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/traceback_util.py:139, in api_boundary.<locals>.reraise_with_filtered_traceback(*args, **kwargs)
136 @util.wraps(fun)
137 def reraise_with_filtered_traceback(*args, **kwargs):
138 try:
--> 139 return fun(*args, **kwargs)
140 except Exception as e:
141 if not is_under_reraiser(e):
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/api.py:815, in value_and_grad.<locals>.value_and_grad_f(*args, **kwargs)
813 tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
814 if not has_aux:
--> 815 ans, vjp_py = _vjp(f_partial, *dyn_args)
816 else:
817 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/api.py:1888, in _vjp(fun, has_aux, *primals)
1886 if not has_aux:
1887 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1888 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1889 out_tree = out_tree()
1890 else:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/ad.py:114, in vjp(traceable, primals, has_aux)
112 def vjp(traceable, primals, has_aux=False):
113 if not has_aux:
--> 114 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
115 else:
116 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/ad.py:101, in linearize(traceable, *primals, **kwargs)
99 _, in_tree = tree_flatten(((primals, primals), {}))
100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
103 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:510, in trace_to_jaxpr(fun, pvals, instantiate)
508 with core.new_main(JaxprTrace) as main:
509 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 510 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
511 assert not env
512 del main, fun, env
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
163 gen = gen_static_args = out_store = None
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
169 # interrupt them before reraising the exception. Otherwise they will only
170 # get garbage-collected at some later time, running their cleanup tasks only
171 # after this exception is handled, which can corrupt the global state.
172 while stack:
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/infer/util.py:248, in potential_energy(model, model_args, model_kwargs, params, enum)
244 substituted_model = substitute(
245 model, substitute_fn=partial(_unconstrain_reparam, params)
246 )
247 # no param is needed for log_density computation because we already substitute
--> 248 log_joint, model_trace = log_density_(
249 substituted_model, model_args, model_kwargs, {}
250 )
251 return -log_joint
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:270, in log_density(model, model_args, model_kwargs, params)
249 def log_density(model, model_args, model_kwargs, params):
250 """
251 Similar to :func:`numpyro.infer.util.log_density` but works for models
252 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
273 return result.data, model_trace
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
157 model = substitute(model, data=params)
158 with plate_to_enum_plate():
--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
160 log_factors = []
161 time_to_factors = defaultdict(list) # log prob factors
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 87 (4 times)]
File ~/.conda/envs/iub1/lib/python3.8/site-packages/numpyro/primitives.py:87, in Messenger.__call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
Input In [11], in tim_dont_overfit_model(X, t)
16 # Y = coeffs * X
17 print(n_variables, to_include.shape, X.shape, coeffs.shape)
---> 18 Y = jnp.dot(to_include * X, coeffs)
20 # Z = Y - mean(Y)
21 Z = jnp.where(Y - jnp.mean(Y) < 0, 1., 0.)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5333, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
5331 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
5332 return NotImplemented
-> 5333 return binary_op(self, other)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:396, in _maybe_bool_binop.<locals>.fn(x1, x2)
394 def fn(x1, x2):
395 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 396 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:352, in mul(x, y)
350 def mul(x: Array, y: Array) -> Array:
351 r"""Elementwise multiplication: :math:`x \times y`."""
--> 352 return mul_p.bind(x, y)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/core.py:259, in Primitive.bind(self, *args, **params)
257 top_trace = find_top_trace(args)
258 tracers = map(top_trace.full_raise, args)
--> 259 out = top_trace.process_primitive(self, tracers, params)
260 return map(full_lower, out) if self.multiple_results else full_lower(out)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/core.py:597, in EvalTrace.process_primitive(self, primitive, tracers, params)
596 def process_primitive(self, primitive, tracers, params):
--> 597 return primitive.impl(*tracers, **params)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/xla.py:230, in apply_primitive(prim, *args, **params)
228 def apply_primitive(prim, *args, **params):
229 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 230 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
231 return compiled_fun(*args)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/util.py:197, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
195 return f(*args, **kwargs)
196 else:
--> 197 return cached(bool(config.x64_enabled), *args, **kwargs)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/util.py:190, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
188 @functools.lru_cache(max_size)
189 def cached(_, *args, **kwargs):
--> 190 return f(*args, **kwargs)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/interpreters/xla.py:255, in xla_primitive_callable(prim, *arg_specs, **params)
252 return prim.bind(*args, **params)
253 return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
254 *arg_specs)
--> 255 aval_out = prim.abstract_eval(*avals, **params)
256 if not prim.multiple_results:
257 handle_result = aval_to_result_handler(device, aval_out)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:2010, in standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, named_shape_rule, *avals, **kwargs)
2007 return ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
2008 weak_type=weak_type)
2009 elif least_specialized is ShapedArray:
-> 2010 return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
2011 weak_type=weak_type,
2012 named_shape=named_shape_rule(*avals, **kwargs))
2013 elif least_specialized is UnshapedArray:
2014 return UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
File ~/.conda/envs/iub1/lib/python3.8/site-packages/jax/_src/lax/lax.py:2106, in _broadcasting_shape_rule(name, *avals)
2104 if result_shape is None:
2105 msg = '{} got incompatible shapes for broadcasting: {}.'
-> 2106 raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
2107 return result_shape
TypeError: mul got incompatible shapes for broadcasting: (2, 1), (250, 200).
This is my first time with Numpyro and probablistics programming. Sorry for the long read and thank you for any help!