Hi all, I’m trying to build a binomial matrix factorization in which the factors are Beta-distributed and the factor loadings are Dirichlet-distributed. I struggled a bit to get the event vs. batch dims right. The model below runs, but I’d love another pair of eyes to check whether it makes sense. For one thing, I feel like I shouldn’t need two separate data plates (`data1` and `data`). Thanks!
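```python
import torch
import pyro
import pyro.distributions as dist


def model(y, total_counts, K):
    N, P = y.shape
    # K x P factor matrix with Beta(1, 1) entries, sampled as one (K, P) event
    psi_dist = dist.Beta(1., 1.).expand([K, P]).to_event(2)
    psi = pyro.sample("psi", psi_dist)
    with pyro.plate('data1', N):
        # one Dirichlet loading vector over the K factors per row
        assign_dist = dist.Dirichlet(torch.ones(K))
        assign = pyro.sample("assign", assign_dist)
    # (N, K) @ (K, P) -> (N, P) success probabilities
    pred = assign @ psi
    with pyro.plate('dims', P):
        with pyro.plate('data', N):
            obs = pyro.sample('obs', dist.Binomial(total_counts, pred), obs=y)
```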
Full example in Colab.