Hi everyone,
Sorry if this is naive — I’m a relatively new user of NumPyro and Pyro.
I’m working on a hierarchical negative binomial model in NumPyro with multiple categorical random effects. When I run inference or predictive sampling on a test set, I get a broadcasting shape mismatch error:
ValueError: Incompatible shapes for broadcasting: shapes=[(train_size,), (test_size,)]
From debugging, it seems the problem happens inside the numpyro.sample()
statement for the observed data, where the shape of the latent variables and the observations don’t align properly.
I also tried using the obs_mask
argument to mask observed vs. missing data, hoping to sample unobserved values separately, but this leads to a runtime error because obs_mask
currently does not support distributions like Poisson that lack enumeration support.
This makes it tricky to handle train/test splits where test data requires sampling predictive values without conditioning on observations.
Has anyone faced this limitation or found workarounds for handling posterior predictive sampling with differing categorical random effects sizes between train and test sets in NumPyro?
Thanks for any insights!
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam
from numpyro.infer import Predictive
# Minimal dummy data setup
N = 10 # number of data points
cat_sizes = {"cat1": 3, "cat2": 4}
# Categorical indices must have length N
cat_indices = {
"cat1": jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0]), # length N=10
"cat2": jnp.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1]), # length N=10
}
# Observed counts (y) length N=10, no NaNs allowed
y = jnp.array([5, 3, 6, 2, 7, 4, 5, 3, 6, 2])
def nb_random_effects_model(y, cat_indices, cat_sizes):
sigma = {}
effects = {}
for col in cat_indices:
sigma[col] = numpyro.sample(f"sigma_{col}", dist.HalfCauchy(1.0))
r = numpyro.sample("r", dist.Exponential(1.0))
for col in cat_indices:
with numpyro.plate(col, cat_sizes[col]):
effects[col] = numpyro.sample(f"alpha_{col}", dist.Normal(0.0, sigma[col]))
eta = sum(effects[col][cat_indices[col]] for col in cat_indices)
mu = jnp.exp(eta)
with numpyro.plate("data", len(cat_indices[next(iter(cat_indices))])):
lam = numpyro.sample("lambda", dist.Gamma(r, r / mu))
if y is None:
print(len(mu))
print(len(lam))
numpyro.sample("obs", dist.Poisson(lam), obs=y)
rng_key = random.PRNGKey(0)
optimizer = Adam(1e-2)
guide = AutoNormal(nb_random_effects_model)
svi = SVI(nb_random_effects_model, guide, optimizer, loss=Trace_ELBO())
n_steps = 100
state = svi.init(rng_key, y, cat_indices, cat_sizes)
for step in range(n_steps):
state, loss = svi.update(state, y, cat_indices, cat_sizes)
if step % 100 == 0:
print(f"Step {step} loss: {loss:.2f}")
params = svi.get_params(state)
print("Learned parameters:")
for k, v in params.items():
print(f"{k}: {v.shape if hasattr(v, 'shape') else v}")
cat_indices_test = {
"cat1": jnp.array([0, 1, 0]), # length 3
"cat2": jnp.array([1, 0, 3]) # length 3
}
y_test = None # No observed data
predictive = Predictive(nb_random_effects_model, guide=guide, params=params, num_samples=10)
samples = predictive(rng_key, y_test, cat_indices_test, cat_sizes)```