Nesting plates with 2D data and multinomial samples

I am having trouble getting the shapes and slices right in a model with the following overall structure:

with numpyro.plate('locus', l, dim=-1):
    omega = numpyro.sample("omega", dist.Exponential(0.5))
    alpha = vmap_gen_alpha(omega, A, pimat, pimult, pimatinv, scale)
    with numpyro.plate('ancestor', 61, dim=-2):
        numpyro.sample('obs', dist.DirichletMultinomial(concentration=alpha, total_count=N), obs=obs_mat)

(full code is here, in particular this section)

I have:

  • l = 20 (loci)
  • N is a vector of length 20.
  • obs_mat has shape (61, 20)
  • vmap_gen_alpha is a function built with jax.vmap and vectorised over omega. For a single omega input it returns a 61x61 matrix, so this call returns an alpha with shape (20, 61, 61). All other parameters are constant in this plate.

Without the plates (and in numpy notation), I would like to run something like:

for locus in range(l):
    omega = numpyro.sample(f"omega_{locus}", dist.Exponential(0.5))
    alpha = gen_alpha(omega, A, pimat, pimult, pimatinv, scale)  # (61, 61)
    for codon in range(61):
        # site names must be unique; obs_mat has shape (61, 20), so one locus is one column
        numpyro.sample(f"obs_{locus}_{codon}", dist.DirichletMultinomial(concentration=alpha[codon, :], total_count=N[locus]), obs=obs_mat[:, locus])

But getting the plate dims/shapes right in the second level of nesting is eluding me.

ValueError: Incompatible shapes for broadcasting: shapes=[(20, 61), (20,)]

Within a doubly nested plate, how do I:

  • Obtain vector input for the multinomial concentration and obs, over the correct dimensions
  • Use different slices/dimensions for alpha and N/obs_mat?

I’ve tried reading Pyro’s guide to tensor shapes, but couldn’t work out how to apply it to my problem.

I would be very grateful for any advice given!

hi john, sounds like you need to reshape the output of vmap_gen_alpha so that it has shape (61, 20, 61). as you nest inward in plates (w.r.t. python context managers), the plate's dimension moves left.

Hi Martin,

Thanks so much for your help; this got me on the right track. There were a few things going wrong in the plates (with all of alpha, N, and obs_mat). I went through these one by one, and the solution was:

  • Change the plate dims to be -2 then -1 (i.e. swap them)
  • For obs_mat, I needed to copy the data 61 times along the inner plate so it has shape (20, 61, 61).
  • For N, I also needed a copy so it has shape (20, 61).
  • alpha was actually fine with these dim changes, but bugs in the function generating it made all entries NaN, which caused an error about the concentration being invalid (another function it called also needed vmap).

Looks like it’s running, so this issue is sorted! (code)
I’m still having some efficiency problems, but that’s for another time.

generally speaking, you should never need to copy data, since broadcasting should take care of things if you’ve set things up right
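The original error can be reproduced from shapes alone. Judging from the error message, `DirichletMultinomial` broadcasts the batch part of `concentration` (every axis except the last) against `total_count`; the helper `batch_shape` below is a hypothetical illustration of that rule, not numpyro code. With alpha of shape (20, 61, 61) this compares (20, 61) with (20,), which fails because trailing axes are matched first. Martin's (61, 20, 61) layout, or a trailing singleton axis on N, both broadcast:

```python
import numpy as np


def batch_shape(alpha_shape, n_shape):
    # Batch shape of a DirichletMultinomial: all but the last axis of the
    # concentration, broadcast against the shape of total_count
    return np.broadcast_shapes(alpha_shape[:-1], n_shape)


# (20, 61, 61) -> batch (20, 61) vs N (20,): trailing axes 61 and 20 clash
try:
    batch_shape((20, 61, 61), (20,))
except ValueError as err:
    print("fails:", err)

print(batch_shape((61, 20, 61), (20,)))    # (61, 20): N broadcasts along the codon axis
print(batch_shape((20, 61, 61), (20, 1)))  # (20, 61): swapped plates with N[:, None]
```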

Ok, I can give that another look. But do you have any advice for data with shape (20, 61) (or its transpose) entering a set of plates nested over 20, then 61? Inside the second plate the data needs to be a (61,) vector, and I don’t want it to be indexed/flattened into a scalar observation, because the 61-loop isn’t over this data but over the alpha parameter, which is (61, 61).

I currently do the broadcast as a transform before entering any jax/numpyro parts:
obs_mat[l, :, :] = np.broadcast_to(X[:, l], (61, 61))
Do you think it’s possible to move this into the numpyro section to avoid a copy?
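A small NumPy check of the no-copy route, assuming `X` is the (61, 20) data matrix from the snippet above (filled with dummy integers here). Inserting a singleton axis with `X.T[:, None, :]` gives shape (20, 1, 61), which a distribution's `log_prob` can broadcast against a (20, 61) batch with a (61,) event. `np.broadcast_to` is used only to show that the expanded (20, 61, 61) array is a read-only view of `X` and holds the same values the explicit loop writes:

```python
import numpy as np

L, K = 20, 61
X = np.arange(K * L).reshape(K, L)  # dummy stand-in for the (61, 20) data

# Explicit copy, as in the loop above
obs_copy = np.empty((L, K, K), dtype=X.dtype)
for l in range(L):
    obs_copy[l, :, :] = np.broadcast_to(X[:, l], (K, K))

# Same values via a singleton axis; broadcasting yields a view, not a copy
obs_lean = X.T[:, None, :]                       # (20, 1, 61)
obs_view = np.broadcast_to(obs_lean, (L, K, K))  # (20, 61, 61), read-only view

print(np.array_equal(obs_view, obs_copy), np.shares_memory(obs_view, X))  # True True
```

In practice `obs_lean` itself can be passed as `obs`; the broadcast happens inside the distribution.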

Perhaps a simpler example is N, which has shape (20,). I want this to be a scalar at the inner plate, but I get a shape error because it has no dimension of size 61 to vectorise over. The only workaround I found was broadcasting it to (20, 61), similar to the above.

i’m not really sure i understand your question(s), but i’d recommend taking a look at some pyro/numpyro models with more complex plate structure to see how nested plates are handled. for example:

Logistic growth models of SARS-CoV-2 lineage proportions — Pyro Tutorials 1.8.2 documentation

also (re)read the first half of the tensor shape tutorial:
Tensor shapes in Pyro — Pyro Tutorials 1.8.2 documentation

in particular, you need to understand the event dimension of distributions (the rightmost dims, which plates never touch) and how it differs from the batch dimensions.

if these don’t help, can you try to ask the narrowest question possible?

Thanks for the links; both have been helpful, but I couldn’t quite work out how to apply them to my specific case (also, the use of unsqueeze doesn’t seem to be explained anywhere).
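For reference, the `unsqueeze` in the Pyro tutorials is PyTorch's `Tensor.unsqueeze`, which inserts a size-1 axis at the given position. In JAX/NumPy the same operation is spelled `x[:, None]` or `jnp.expand_dims(x, -1)`, and `squeeze` removes the axis again. That size-1 axis is what lets a per-position vector broadcast against a (positions, something) batch:

```python
import jax.numpy as jnp

N = jnp.array([50, 40, 50, 40, 45])  # shape (5,)
a = N[:, None]                       # shape (5, 1), like torch's N.unsqueeze(-1)
b = jnp.expand_dims(N, -1)           # the same operation, spelled differently

batch = jnp.zeros((5, 3))            # e.g. a (positions, ancestors) batch
print((a + batch).shape)             # (5, 3): the size-1 axis is broadcast
print(a.squeeze(-1).shape)           # (5,): squeeze undoes the unsqueeze
```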

While trying to think of a clearer way to ask this question, I think I’ve worked out what I’m doing wrong, so I’ll explain it here:

My model is count_i ~ Σ_j P(j) · DirichletMultinomial(N_i, concentration_ij), i.e. a mixture over ancestors,
where i is over positions, j is over possible ancestors, P(j) is the probability of j being the ancestor, N_i is the total count at position i, and concentration_ij is the expected proportions of counts at position i for ancestor j. To make this more concrete, let’s say I have five positions, two possible ancestors, and three possible outcomes for each multinomial.

My data is then:

  • The observed count matrix obs_mat with shape (5, 3)
  • The total counts N with shape (5, )
  • Expected concentrations alpha with shape (5, 3, 2) (in the real model this is computed from other parameters, but let’s treat this as fixed).

I think I’m right to have the outer plate over the five positions i, but where I’ve been going wrong is having an inner plate over the two ancestors j. The P(j) part above doesn’t work as a plate: a plate declares conditionally independent observations, whereas here the same counts are explained by each possible ancestor, so the dimensions of the multinomial don’t change. I believe this should be a jax loop instead.

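N = np.array([50, 40, 50, 40, 45])  # 50 counts at position 0, 40 counts at position 1, etc.
concentration = np.array([[0.5, 0.5], [0.7, 0.3]])
# alpha: (5, 3, 2) expected concentrations; obs_mat: (5, 3) observed counts
with numpyro.plate('position', 5):  # 5 independent positions
    for j in range(2):  # loop over ancestors j (weighting by P(j) still to do); NB I imagine a jax loop is better
        # each site needs a unique name, otherwise numpyro raises a duplicate-site error
        numpyro.sample(f'obs_{j}', dist.DirichletMultinomial(concentration=alpha[:, :, j], total_count=N), obs=obs_mat)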

I now need to work out how to weight these samples by P(j). Previously I was using numpyro.handlers.scale, but I need to do a bit more testing. As this is probably off the original topic, I’ll open a more targeted thread if I need help there.