I am having trouble getting the shapes and slices in a model with the following overall structure to work:
    with numpyro.plate('locus', l, dim=-1):
        omega = numpyro.sample("omega", dist.Exponential(0.5))
        alpha = vmap_gen_alpha(omega, A, pimat, pimult, pimatinv, scale)
        with numpyro.plate('ancestor', 61, dim=-2):
            numpyro.sample('obs', dist.DirichletMultinomial(concentration=alpha, total_count=N), obs=obs_mat)
(full code is here, in particular this section)
I have:
- l = 20 (loci)
- N is a vector of length 20.
- obs_mat has shape (61, 20)
- vmap_gen_alpha is a function from jax.vmap which is vectorised over omega; it would return a 61x61 matrix for a single omega input, such that this call returns an alpha with shape (20, 61, 61). All other parameters are constant in this plate.
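To make the shapes concrete, here is a minimal sketch of the vmap step. The names gen_alpha_stub and vmap_gen_alpha_stub and the placeholder matrix A are hypothetical stand-ins, not the real function from my code; the only thing the sketch is meant to show is how a per-omega 61x61 output becomes a (20, 61, 61) array.

```python
import jax
import jax.numpy as jnp


def gen_alpha_stub(omega, A):
    # Hypothetical stand-in for gen_alpha: returns one 61x61 matrix
    # for a single scalar omega.
    return omega * A


A = jnp.ones((61, 61))               # placeholder for the constant inputs
omega = jnp.linspace(0.1, 2.0, 20)   # one omega per locus (l = 20)

# Vectorise over omega only; A is held constant across loci (in_axes=None).
vmap_gen_alpha_stub = jax.vmap(gen_alpha_stub, in_axes=(0, None))
alpha = vmap_gen_alpha_stub(omega, A)

print(alpha.shape)  # (20, 61, 61)
```

So the locus axis ends up first, followed by the two codon axes.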
Without the plates (and in numpy notation), I would like to run something like:
    for locus in range(l):
        omega[locus] = numpyro.sample(f"omega_{locus}", dist.Exponential(0.5))
        alpha = gen_alpha(omega[locus], A, pimat, pimult, pimatinv, scale)
        for codon in range(61):
            numpyro.sample(f"obs_{locus}_{codon}", dist.DirichletMultinomial(concentration=alpha[codon, :], total_count=N[locus]), obs=obs_mat[locus, :])
But getting the plate dims/shapes right in the second level of nesting is eluding me. The nested-plate version raises:

    ValueError: Incompatible shapes for broadcasting: shapes=[(20, 61), (20,)]
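As far as I can tell, the two shapes in the message are the batch shape of concentration and the shape of total_count. The sketch below reproduces only that shape check in isolation, assuming the last axis of alpha is treated as the event dimension so that its batch shape is (20, 61); the variable names are illustrative and this is not the model itself.

```python
from jax import lax

conc_batch_shape = (20, 61)   # alpha.shape[:-1], assuming the last axis is the event dim
total_count_shape = (20,)     # N.shape

# Broadcasting aligns trailing axes, so 61 is compared against 20 and fails.
try:
    lax.broadcast_shapes(conc_batch_shape, total_count_shape)
    mismatch = None
except ValueError as err:
    mismatch = str(err)
print(mismatch)

# With a trailing singleton axis on N (e.g. N[:, None]) the shapes are compatible.
ok_shape = lax.broadcast_shapes(conc_batch_shape, (20, 1))
print(ok_shape)  # (20, 61)
```

Adding the singleton axis makes the shapes broadcast, but it does not by itself tell me which plate dimension each axis should line up with, which is the part I am stuck on.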
Within a doubly nested plate, how do I:
- Obtain vector input for the multinomial concentration and obs, over the correct dimensions?
- Use different slices/dimensions for alpha and N/obs_mat?
I’ve tried reading Pyro’s guide to tensor shapes, but couldn’t work out how to apply it to my problem.
I would be very grateful for any advice given!