Broadcasting Error with `numpyro.plate_stack` and `dist.Normal`: Incompatible Shapes

I’m encountering a broadcasting error when sampling from a `Normal` distribution inside a `numpyro.plate_stack` context. The error occurs because the location parameter `mu` has shape `(5, 100)`, while the scale parameter `sigma`, which is sampled inside the plate, has shape `(5,)`.

Minimal Reproducible Example

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def some_transformation(x):
    return jnp.squeeze(x, axis=-1)

def simple_model(x, y=None):
    mu = some_transformation(x)  # Shape: (num_particles, 100)
    sigma = numpyro.sample("sigma", dist.LogNormal(0, 1.0))  # Shape: (num_particles,)

    print(f"mu shape: {mu.shape}, sigma shape: {sigma.shape}")

    # This causes the broadcasting error
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

def vectorized(fn, *shape, name="vectorization_plate"):
    # Run every sample site of fn inside a stack of plates with the given sizes
    def wrapper_fn(*args, **kwargs):
        with numpyro.plate_stack(name, [*shape]):
            return fn(*args, **kwargs)
    return wrapper_fn

# Test case
num_particles = 5
x_dummy = jnp.zeros((num_particles, 100, 1))

trace = numpyro.handlers.trace(
    numpyro.handlers.seed(vectorized(simple_model, num_particles), jax.random.key(0))
).get_trace(x_dummy)
```

Error Message

```
ValueError: Incompatible shapes for broadcasting: shapes=[(), (5, 100), (5,)]
```

I expect `dist.Normal(mu, sigma)` to create a distribution where:

  • `mu` has shape `(num_particles, 100)`
  • `sigma`, with shape `(num_particles,)`, broadcasts first to `(num_particles, 1)` and then to `(num_particles, 100)`
  • the resulting distribution has batch shape `(num_particles, 100)`
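For comparison, JAX follows NumPy's broadcasting rule, which aligns shapes from the right. A quick check with `jnp.broadcast_shapes`, independent of NumPyro, shows how `(5,)` and `(5, 1)` each behave against `(5, 100)`. The variable names here are illustrative only.

```python
import jax.numpy as jnp

mu_shape = (5, 100)

# (5,) is aligned with the trailing axis of size 100, so broadcasting fails
try:
    jnp.broadcast_shapes(mu_shape, (5,))
except ValueError as err:
    print("(5,) fails:", err)

# (5, 1) lines the particle axis up with mu's leading axis
print(jnp.broadcast_shapes(mu_shape, (5, 1)))  # (5, 100)

# Adding a trailing axis to a (5,) array gives that (5, 1) shape
sigma = jnp.ones(5)
print(sigma[..., None].shape)  # (5, 1)
```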

Thank you in advance.