 Sampling multiple values from a MultivariateNormal

I’m working on implementing an inference algorithm for a GMM (Gaussian mixture model). Right now, I’m having trouble with the variational parameters for the cluster parameters in my guide. I want tau to have shape (T x M), where T = 5 is the number of clusters and M = 2 is the number of features per observation, so that I can optimize each cluster mean individually. Here is my code:

```python
import torch
import pyro
from pyro.distributions import Categorical, Dirichlet, MultivariateNormal

def model(data):
    # global variables
    weights = torch.ones(T) / T
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(torch.zeros(M), torch.eye(M)))

    # local variables
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))
    for i in pyro.plate('data', N):
        pyro.sample('obs_{}'.format(i), MultivariateNormal(locs[assignments[i]], torch.eye(M)), obs=data[i])
```

```python
def guide(data):
    # amortize using MLP
    pyro.module('alpha_mlp', alpha_mlp)

    # sample mixture components mu
    tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))

    # sample cluster assignments
    alpha = alpha_mlp(data.double())
    weights = pyro.param('weights', Dirichlet(alpha))
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))
```

Specifically, it’s these lines that I’m having trouble with. I keep getting a tensor of length 5, but I expected a (5 x 2) tensor, since I’m sampling from a MultivariateNormal of dimension 2. What am I missing?

```python
    # sample mixture components mu
    tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))
```

Is there a specific need for the for-loop version of `pyro.plate`? I’d recommend the `with pyro.plate(...)` idiom instead; it vectorizes each sample site over the plate and might help clear up the dimensionality issues.

It would also help to include an MWE (minimal working example) with dummy data so others can run the code too.