Hi – the following is a part of a more complex model but it captures the error I’m getting. I’ve read all the questions here that seem related but I still don’t understand what’s causing the error. I should be getting an array of shape (3,2) == 3 length scales for each 2 independent GRF samples.
import jax
from jax import numpy as jnp
import numpyro
import numpyro.distributions as dist
mesh_extents = jnp.array([200, 300, 400])
def model(mesh_extents):
n_dim = len(mesh_extents)
with numpyro.plate('GRFs', 2):
with numpyro.plate('dimensions', n_dim):
length_scales = numpyro.sample(
'length_scales',
dist.Uniform(mesh_extents/4, mesh_extents)
)
return length_scales
# try sampling
numpyro.handlers.seed(model, jax.random.key(1))(mesh_extents)
which throws
Traceback (most recent call last):
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 216, in _broadcast_shapes_uncached
return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 133, in _try_broadcast_shapes
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (3, 2), (3, 3).
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 194, in broadcast_shapes
return _broadcast_shapes_cached(*shapes)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/util.py", line 298, in wrapper
return cached(config.trace_context() if trace_context_in_key else _ignore(),
*args, **kwargs)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/util.py", line 292, in cached
return f(*args, **kwargs)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 200, in _broadcast_shapes_cached
return _broadcast_shapes_uncached(*shapes)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 219, in _broadcast_shapes_uncached
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 2), (3, 3)]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 216, in _broadcast_shapes_uncached
return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 133, in _try_broadcast_shapes
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (3, 2), (3, 3).
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<python-input-0>", line 17, in <module>
numpyro.handlers.seed(model, jax.random.key(1))(mesh_extents)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 849, in __call__
return cloned_seeded_fn.__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 850, in __call__
return super().__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 119, in __call__
return self.fn(*args, **kwargs)
~~~~~~~^^^^^^^^^^^^^^^^^
File "<python-input-0>", line 10, in model
length_scales = numpyro.sample(
'length_scales',
dist.Uniform(mesh_extents/4, mesh_extents)
)
File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 248, in sample
msg = apply_stack(initial_msg)
File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 53, in apply_stack
handler.process_message(msg)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^
File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 594, in process_message
broadcast_shape = lax.broadcast_shapes(
trailing_shape, tuple(dist_batch_shape)
)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 196, in broadcast_shapes
return _broadcast_shapes_uncached(*shapes)
File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 219, in _broadcast_shapes_uncached
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 2), (3, 3)]