Binomial matrix factorization model

Hi all - I’m trying to construct a binomial matrix factorization where the factors are Beta distributed and the factor loadings are Dirichlet distributed. I struggled a bit to get the event vs. batch dims right. The following runs, but I’d love another pair of eyes to tell me whether it makes sense. In particular, I feel like I shouldn’t need two separate data plates. Thanks!

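```python
import torch
import pyro
import pyro.distributions as dist

def model(y, total_counts, K):
    N, P = y.shape
    # K x P factors, each entry Beta(1, 1); both dims are declared as event dims
    psi_dist = dist.Beta(1., 1.).expand([K, P]).to_event(2)
    psi = pyro.sample("psi", psi_dist)

    # one Dirichlet loading vector (event_shape (K,)) per row of y
    with pyro.plate('data1', N):
        assign_dist = dist.Dirichlet(torch.ones(K))
        assign = pyro.sample("assign", assign_dist)

    # (N, K) @ (K, P) -> (N, P) success probabilities
    pred = assign @ psi
    # outer plate 'dims' is allocated dim -1 and inner 'data' dim -2, so together they cover (N, P)
    with pyro.plate('dims', P):
        with pyro.plate('data', N):
            obs = pyro.sample('obs', dist.Binomial(total_counts, pred), obs=y)
```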

Full example in Colab.