Hi, I’m trying to learn multiple (5) coefficients for each treatment group.
So I set up my code so that, for each treatment in the “Treatments” plate, I draw 5 numbers independently from a normal distribution.
I’m expecting a final shape of (2, 5), but my implementation keeps throwing a broadcasting error. Could you please help?
```python
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition, seed, substitute, trace
from jax import random
import jax.numpy as jnp

rng_key = random.PRNGKey(0)
# `seed` is imported directly above, so it is used without a `handlers.` prefix
with seed(rng_seed=rng_key):
    with numpyro.plate('Treatments', 2):
        # expected shape (2, 5); this is the line that raises the broadcasting error
        numpyro.sample('Beta_CpGs', dist.Normal(jnp.zeros(5), jnp.ones(5)))
```