Sampling multiple values from a MultivariateNormal

I’m implementing an inference algorithm for a Gaussian mixture model (GMM). Right now I’m having trouble with the variational parameters for the cluster means in my guide. I want tau to have shape (T x M) so that I can optimize each cluster mean separately, where T = 5 is the number of clusters and M = 2 is the number of features per observation. Here is my code:

import torch
import pyro
from pyro.distributions import Categorical, Dirichlet, MultivariateNormal

def model(data):
    # global variables 
    weights = torch.ones(T) / T
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(torch.zeros(M), torch.eye(M)))

    # local variables
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))
    for i in pyro.plate('data', N):
        pyro.sample('obs_{}'.format(i), MultivariateNormal(locs[assignments[i]], torch.eye(M)), obs=data[i])
        
def guide(data):
    # amortize using MLP
    pyro.module('alpha_mlp', alpha_mlp)
    
    # sample mixture components mu
    tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))
    
    # sample cluster assignments
    alpha = alpha_mlp(data.double()) 
    weights = pyro.param('weights', Dirichlet(alpha)) 
    assignments = pyro.sample('assignments', Categorical(weights).expand([N]))

Specifically, these are the lines I’m having trouble with. I keep getting a tensor of length 5, but since I’m sampling from a MultivariateNormal of dimension 2, I expected a (5x2) tensor:

    # sample mixture components mu
    tau = pyro.param('tau', MultivariateNormal(torch.zeros(M), torch.eye(M)).expand([T]))
    for i in pyro.plate('components', T):
        locs = pyro.sample('locs_{}'.format(i), MultivariateNormal(tau[i], torch.eye(M)))

Is there a specific reason for using the sequential (for-loop) version of pyro.plate? I’d recommend the vectorized with pyro.plate(...) idiom instead; it might help clear up the dimensionality issues.

It would also help to include a minimal working example (MWE) with dummy data so that others can run the code.