I’m getting a broadcasting error when sampling from a Normal distribution inside a numpyro.plate_stack context. The error occurs when the location parameter (mu) has shape (5, 100) and the scale parameter (sigma) has shape (5,), which it gets from being sampled inside the plate.
Minimal Reproducible Example
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
def some_transformation(x):
    return jnp.squeeze(x, axis=-1)

def simple_model(x, y=None):
    mu = some_transformation(x)  # Shape: (num_particles, 100)
    sigma = numpyro.sample("sigma", dist.LogNormal(0, 1.0))  # Shape: (num_particles,)
    print(f"mu shape: {mu.shape}, sigma shape: {sigma.shape}")
    # This causes the broadcasting error
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

def vectorized(fn, *shape, name="vectorization_plate"):
    def wrapper_fn(*args, **kwargs):
        with numpyro.plate_stack(name, [*shape]):
            return fn(*args, **kwargs)
    return wrapper_fn

# Test case
num_particles = 5
x_dummy = jnp.zeros((num_particles, 100, 1))
trace = numpyro.handlers.trace(
    numpyro.handlers.seed(vectorized(simple_model, num_particles), jax.random.key(0))
).get_trace(x_dummy)
Error Message
ValueError: Incompatible shapes for broadcasting: shapes=[(), (5, 100), (5,)]
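The same shape check seems to fail outside NumPyro as well. A minimal sketch, assuming the error comes from ordinary NumPy-style broadcasting, which aligns shapes from the trailing axis (the helper `can_broadcast` is a name made up for this illustration):

```python
import jax.numpy as jnp


def can_broadcast(*shapes):
    """Return True if the shapes are broadcast-compatible under NumPy rules."""
    try:
        jnp.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


# Shapes are right-aligned, so the 5 in (5,) is compared against the 100.
print(can_broadcast((), (5, 100), (5,)))    # False

# With an explicit trailing axis, 5 lines up with 5 and 1 stretches to 100.
print(can_broadcast((), (5, 100), (5, 1)))  # True
```

So (5,) is never promoted to (5, 1) automatically; new axes are only added on the left.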
I expect dist.Normal(mu, sigma) to create a distribution where:
- mu has shape (num_particles, 100)
- sigma with shape (num_particles,) broadcasts to (num_particles, 1) and then to (num_particles, 100)
- The resulting distribution should have batch shape (num_particles, 100)
Thank you in advance.