Hi all, I'm trying to construct a binomial matrix factorization where the factors are Beta distributed and the factor loadings are Dirichlet distributed. I was struggling a bit to get the event vs. batch dims right. The following runs, but I'd love another pair of eyes to tell me whether it makes sense. For one thing, I feel like I shouldn't need the two data plates. Thanks!
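```python
import torch
import pyro
import pyro.distributions as dist

def model(y, total_counts, K):
    N, P = y.shape
    # K x P matrix of Beta-distributed factors, treated as a single event
    psi_dist = dist.Beta(1., 1.).expand([K, P]).to_event(2)
    psi = pyro.sample("psi", psi_dist)
    # one Dirichlet-distributed loading vector of length K per row
    with pyro.plate('data1', N):
        assign_dist = dist.Dirichlet(torch.ones(K))
        assign = pyro.sample("assign", assign_dist)
    # N x P matrix of success probabilities
    pred = assign @ psi
    with pyro.plate('dims', P):
        with pyro.plate('data', N):
            obs = pyro.sample('obs', dist.Binomial(total_counts, pred), obs=y)
```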
Full example in Colab.