I’m implementing an inference algorithm for a Gaussian mixture model (GMM). Right now, I’m having trouble with the variational parameters for the cluster means in my guide. I want tau to have shape (T x M) so that I can optimize each cluster mean separately, where T = 5 is the number of clusters and M = 2 is the number of features per observation. Here is my code:
import torch
import pyro
from pyro.distributions import Categorical, Dirichlet, MultivariateNormal

T = 5  # number of clusters
M = 2  # number of features per observation

def model(data):
    # global variables
    weights = torch.ones(T) / T
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(torch.zeros(M), torch.eye(M)))
    # local variables
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))
    for i in pyro.plate('data', N):
        pyro.sample('obs_{}'.format(i), MultivariateNormal(locs[assignments[i]], torch.eye(M)), obs=data[i])
def guide(data):
    # amortize using MLP
    pyro.module('alpha_mlp', alpha_mlp)
    # sample mixture components mu
    tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))
    # sample cluster assignments
    alpha = alpha_mlp(data.double())
    weights = pyro.param('weights', Dirichlet(alpha))
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))
Specifically, it’s these lines of code that I’m having trouble with. I keep getting a tensor of length 5, but since I’m sampling from a 2-dimensional MultivariateNormal, shouldn’t I get a (5 x 2) tensor?
# sample mixture components mu
tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
for i in pyro.plate('components', T):
    locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))