Multivariate Poisson Mixture Model with auto marginalisation

Hi, I am new to Pyro and would like to create a multivariate Poisson mixture model with it.
Based on the GMM tutorial, I attempted this:

import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
import matplotlib.pyplot as plt
%matplotlib inline
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Trace_ELBO
from pyro.infer import Predictive
from pyro.ops.indexing import Vindex
smoke_test = "CI" in os.environ
assert pyro.__version__.startswith('1.9.1')

K = 4 # Fixed number of components.
n_obs = 100
n_features = 5
real_assignment = torch.randint(high=K, size=(n_obs, 1))
real_means = torch.tensor([1,5,10,20])
data = torch.tensor(scipy.stats.poisson.rvs(mu=real_means[real_assignment], size=(n_obs,n_features))).float()

@config_enumerate
def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
    with pyro.plate("components", K):
        means = pyro.sample("means", dist.Uniform(1.0, 20.0))

    with pyro.plate("data", len(data)):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        expanded_means = means[assignment][:,None].expand(-1,data.shape[1])
        pyro.sample("obs", dist.Poisson(expanded_means).to_event(1), obs=data)

optim = pyro.optim.Adam({"lr": 0.1, "betas": [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

global_guide = AutoDelta(
        poutine.block(model, expose=["weights", "means"]))
svi = SVI(model, global_guide, optim, loss=elbo)

losses = []
for i in range(200):
    loss = svi.step(data)
    losses.append(loss)
    print("." if i % 100 else "\n", end="")

print(global_guide(data))

In this model, one event is 5 draws from the same Poisson distribution (the one of the component it is assigned to).
I want the latent categorical variables to be marginalised out, and if I understood correctly, this can be done using TraceEnum_ELBO.
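For reference, here is a minimal torch-only sketch of what marginalising the assignment computes, assuming point values for weights and means (which is what the AutoDelta guide supplies). marginal_log_likelihood is a hypothetical helper, not part of Pyro: for each row it sums the log Poisson probabilities of the 5 draws under each component, adds the log weights, and applies logsumexp over the K components. Up to the prior terms on weights and means, this should be the sum that TraceEnum_ELBO evaluates exactly when assignment is enumerated in parallel.

```python
import torch
from torch.distributions import Poisson


def marginal_log_likelihood(data, weights, means):
    """Hypothetical helper: log p(data | weights, means) with assignments summed out.

    data:    (N, D) counts, each row is D iid Poisson draws sharing one rate
    weights: (K,) mixture weights summing to 1
    means:   (K,) Poisson rate of each component
    """
    # log Poisson(x_nd | means_k) for every row n, component k, feature d -> (N, K, D)
    log_px = Poisson(means[None, :, None]).log_prob(data[:, None, :])
    # The D draws of one event are independent given k: sum over features -> (N, K)
    log_joint = torch.log(weights) + log_px.sum(-1)
    # Marginalise the discrete assignment with logsumexp over k, then sum over rows
    return torch.logsumexp(log_joint, dim=-1).sum()


# Small usage example with made-up rates (assumption, mirrors the question's setup)
torch.manual_seed(0)
weights = torch.full((4,), 0.25)
means = torch.tensor([1.0, 5.0, 10.0, 20.0])
assignment = torch.randint(4, (10,))
data = torch.poisson(means[assignment].unsqueeze(-1).expand(10, 5))
print(marginal_log_likelihood(data, weights, means))
```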

I get this error when running the code:

RuntimeError: expand(torch.FloatTensor{[4, 1, 1]}, size=[-1, 5]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)
  Trace Shapes:          
   Param Sites:          
  Sample Sites:          
   weights dist       | 4
          value       | 4
components dist       |  
          value     4 |  
     means dist     4 |  
          value     4 |  
      data dist       |  
          value   100 |  
assignment dist   100 |  
          value 4   1 | 
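The shapes in this trace can be reproduced with torch alone. This sketch assumes, following the assignment row above, that parallel enumeration replaces assignment by all K values arranged with shape (4, 1). original_expand mirrors the expanded_means line of the model; safe_expand is a hypothetical alternative that appends the feature dimension on the right and keeps every leading dimension.

```python
import torch

K, n_obs, n_features = 4, 100, 5
means = torch.tensor([1.0, 5.0, 10.0, 20.0])  # one rate per component, shape (K,)

# assignment as sampled normally: one component index per observation, shape (100,)
sequential = torch.randint(K, (n_obs,))
# assignment under parallel enumeration (assumed from the trace): all K values, shape (4, 1)
enumerated = torch.arange(K).reshape(K, 1)


def original_expand(indexed, n_features):
    # Mirrors the question: [:, None] inserts a dim at position 1,
    # and expand(-1, n_features) only accepts a 2-D tensor
    return indexed[:, None].expand(-1, n_features)


def safe_expand(indexed, n_features):
    # Hypothetical alternative: add the feature dim on the right, keep all leading dims
    return indexed.unsqueeze(-1).expand(*indexed.shape, n_features)


for assignment in (sequential, enumerated):
    indexed = means[assignment]
    try:
        original = tuple(original_expand(indexed, n_features).shape)
    except RuntimeError:
        original = "RuntimeError"
    print(tuple(assignment.shape), "original:", original,
          "safe:", tuple(safe_expand(indexed, n_features).shape))
```

This prints (100, 5) for both versions in the sequential case, and RuntimeError versus (4, 1, 5) in the enumerated case, matching the expand error above.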

Can someone help me, please?

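The error points to the expanded_means computation under enumeration. When a site is enumerated in parallel, Pyro adds a new dimension to the left of the plate dimensions: as the assignment row of the trace shows, assignment has shape (4, 1) instead of (100,). means[assignment][:,None] then has shape (4, 1, 1), and expand(-1, data.shape[1]), which assumes a 2-D tensor, fails. Add the feature dimension on the right (e.g. with unsqueeze(-1)) and expand only that last dimension, so that any extra dimensions on the left are preserved. You can print out the shapes of those variables to make sure that operators like expand work both with and without enumeration.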