ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (5,)]

Hi, I’m trying to learn multiple (5) coefficients for each treatment group.
So I set up my code so that, for each treatment in the “Treatments” plate, I draw 5 numbers independently from a normal distribution.

I’m expecting a final shape of (2, 5), but my implementation keeps throwing a broadcasting error. Could you please help?

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from jax import random
import jax.numpy as jnp

rng_key = random.PRNGKey(0)
with handlers.seed(rng_seed=rng_key):
    with numpyro.plate('Treatments', 2):
        # Raises the ValueError above: batch shape (5,) vs. plate size 2
        numpyro.sample('Beta_CpGs', dist.Normal(jnp.zeros(5), jnp.ones(5)))

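The 'Treatments' plate takes the rightmost batch dimension (size 2), which cannot broadcast against the distribution’s batch shape of (5,). If you’re expecting each of the coefficients within a treatment group to be independent, then you should add another plate statement indicating that: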

with numpyro.plate('Treatments', 2, dim=-2):
    with numpyro.plate('coefficients', 5, dim=-1):
        numpyro.sample('Beta_CpGs', dist.Normal(0, 1))