Error with nested plates + vectorized distribution sampling

Hi, the following is part of a more complex model, but it reproduces the error I'm getting. I've read all the questions here that seem related, but I still don't understand what's causing the error. I expect an array of shape (3, 2): 3 length scales for each of 2 independent GRF samples.

import jax
from jax import numpy as jnp
import numpyro
import numpyro.distributions as dist

mesh_extents = jnp.array([200, 300, 400])
def model(mesh_extents):
    n_dim = len(mesh_extents)
    with numpyro.plate('GRFs', 2):
        with numpyro.plate('dimensions', n_dim):
            length_scales = numpyro.sample(
                'length_scales',
                dist.Uniform(mesh_extents/4, mesh_extents)
            )
    return length_scales

# try sampling
numpyro.handlers.seed(model, jax.random.key(1))(mesh_extents)

which throws

Traceback (most recent call last):
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 216, in _broadcast_shapes_uncached
    return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 133, in _try_broadcast_shapes
    raise TypeError(f'{name} got incompatible shapes for broadcasting: '
                    f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (3, 2), (3, 3).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 194, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/util.py", line 298, in wrapper
    return cached(config.trace_context() if trace_context_in_key else _ignore(),
                  *args, **kwargs)
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/util.py", line 292, in cached
    return f(*args, **kwargs)
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 200, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 219, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 2), (3, 3)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 216, in _broadcast_shapes_uncached
    return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 133, in _try_broadcast_shapes
    raise TypeError(f'{name} got incompatible shapes for broadcasting: '
                    f'{", ".join(map(str, map(tuple, shapes)))}.')
TypeError: broadcast_shapes got incompatible shapes for broadcasting: (3, 2), (3, 3).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<python-input-0>", line 17, in <module>
    numpyro.handlers.seed(model, jax.random.key(1))(mesh_extents)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 849, in __call__
    return cloned_seeded_fn.__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 850, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 119, in __call__
    return self.fn(*args, **kwargs)
           ~~~~~~~^^^^^^^^^^^^^^^^^
  File "<python-input-0>", line 10, in model
    length_scales = numpyro.sample(
        'length_scales',
        dist.Uniform(mesh_extents/4, mesh_extents)
    )
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 248, in sample
    msg = apply_stack(initial_msg)
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 53, in apply_stack
    handler.process_message(msg)
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 594, in process_message
    broadcast_shape = lax.broadcast_shapes(
        trailing_shape, tuple(dist_batch_shape)
    )
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 196, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
  File "[...]/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 219, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 2), (3, 3)]

It works if I replace the nested plate with something like:

length_scales = numpyro.sample(
    'length_scales',
    dist.Uniform(mesh_extents/4, mesh_extents).expand((n_dim,)).to_event(1)
)

but then the conditional independence information along the 'dimensions' axis is lost. Also, although I can't reproduce it now, at some point I was getting warnings about a missing plate with this approach.

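Maybe you can declare the dims explicitly in the plates? For example, plate('dimensions', n_dim, dim=-1) and plate('GRFs', 2, dim=-2), so that the 'dimensions' plate lines up with the rightmost batch dim, which is where the size-3 axis of mesh_extents already sits.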