MAP prediction with enumeration

Hi! I’m currently building a Gaussian mixture model in which each component’s covariance parameter is itself drawn from a second mixture. To realize this idea, I use two sets of discrete indexing variables and rely on enumeration for inference. SVI works well and reaches convergence.
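
For context, my training setup looks roughly like this (a minimal sketch, with illustrative optimizer settings; refresh is the model defined further down, and the discrete sites u and z are hidden from the guide so that TraceEnum_ELBO enumerates them):

import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# Guide over the continuous sites only; u and z are enumerated out.
global_guide = AutoDelta(poutine.block(refresh, hide=["u", "z"]))
svi = SVI(refresh, global_guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
for step in range(2000):
    loss = svi.step(data)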

However, I run into an error when I try to perform classification:

ValueError: Error while packing tensors at site 'inv_Sigma':
  Invalid tensor shape.
  Allowed dims: -1
  Actual shape: (7, 19)

        u dist   19 |    
         value 7  1 |    
      log_prob 7 19 |    
inv_Sigma dist 7 19 | 2 2
         value   19 | 2 2
      log_prob 7 19 |    

where 7 and 19 are the numbers of parameter-mixture components and sample-mixture components, respectively.
It seems to me that the extra dimension of size 7 at dim=-2 is a result of enumeration. However, all I want is a MAP estimate (prediction), so I was wondering: what is wrong with my classification function?

from pyro.infer import infer_discrete

guide_trace = poutine.trace(global_guide).get_trace(data)   # record the globals
trained_model = poutine.replay(refresh, trace=guide_trace)  # replay the globals

def classifier(data, temperature=0):
    # temperature=0 gives the MAP assignment;
    # temperature=1 samples a random assignment from the posterior instead
    inferred_model = infer_discrete(trained_model, temperature=temperature,
                                    first_available_dim=-2)
    trace = poutine.trace(inferred_model).get_trace(data)  # the error is raised here
    return trace.nodes["z"]["value"]
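
For reference, the shape table in the error above should be reproducible with the shape-debugging idiom from the Pyro tensor-shapes tutorial:

# Trace the replayed model under enumeration and print all site shapes.
enum_model = poutine.enum(trained_model, first_available_dim=-2)
trace = poutine.trace(enum_model).get_trace(data)
trace.compute_log_prob()
print(trace.format_shapes())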

Here is my model, in case it helps:

# DPMM model in Pyro
import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.distributions import (Beta, Categorical, Chi2, LKJCholesky,
                                MultivariateNormal, Uniform, Wishart)
from pyro.infer import config_enumerate

# d = data dimension, N = number of observations (set elsewhere in my script)
T_comp = 7   # truncation of the parameter mixture
T_mix = 19   # truncation of the sample mixture
alpha_0 = 0.1
gamma_0 = 0.1

def mix_weights(beta):
    # truncated stick-breaking: turn T-1 stick proportions into T weights
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

@config_enumerate
def refresh(data=None):
    alpha = pyro.param("alpha", torch.tensor([alpha_0]))
    gamma = pyro.param("gamma", torch.tensor([gamma_0]))
    var = pyro.param("var",
                     lambda: Chi2(torch.ones(d)).to_event(1).sample([T_mix]),
                     constraint=constraints.positive)

    with pyro.plate("component_sticks", T_comp - 1):
        beta_comp = pyro.sample("beta_comp", Beta(1, gamma))

    with pyro.plate("component", T_comp):
        # components of the prior for the covariance
        nu = pyro.sample("nu", Uniform(d + 2, 2 * d + 2))
        theta = pyro.sample("theta", Chi2(df=torch.ones(d) * (d + 1)).to_event(1))
        omega = pyro.sample("omega", LKJCholesky(d, concentration=1))
        Omega = torch.bmm(theta.sqrt().diag_embed(), omega)

    with pyro.plate("mixture_sticks", T_mix - 1):
        beta_mix = pyro.sample("beta_mix", Beta(1, alpha))

    with pyro.plate("mixture", T_mix):
        # u indexes which covariance prior each mixture component uses
        u = pyro.sample("u", Categorical(mix_weights(beta_comp)),
                        infer={"enumerate": "parallel"})
        inv_Sigma = pyro.sample("inv_Sigma",
                                Wishart(df=nu[u], scale_tril=Omega[u]))
        mu = pyro.sample("mu", MultivariateNormal(
            torch.zeros(d),
            scale_tril=torch.bmm(var.sqrt().diag_embed(),
                                 torch.eye(d).repeat(T_mix, 1, 1))))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta_mix)),
                        infer={"enumerate": "parallel"})
        pyro.sample("obs",
                    MultivariateNormal(mu[z], precision_matrix=inv_Sigma[z]),
                    obs=data)

I was wondering whether this is related to parallel enumeration. Since one discrete variable influences the downstream distributions, is it OK to perform parallel enumeration in this model?
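
For concreteness, here is a reduced, hypothetical toy model with the same dependency pattern (an enumerated discrete site whose value parameterizes a downstream continuous sample site). It is not my actual model, just the structure I am asking about:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

@config_enumerate
def toy_model(data=None):
    locs = torch.tensor([-2.0, 0.0, 2.0])
    with pyro.plate("mixture", 5):
        # enumerated discrete site ...
        u = pyro.sample("u", dist.Categorical(torch.ones(3) / 3))
        # ... whose value parameterizes a continuous site downstream
        scale = pyro.sample("scale", dist.LogNormal(locs[u], 1.0))
    with pyro.plate("data", 100):
        z = pyro.sample("z", dist.Categorical(torch.ones(5) / 5))
        pyro.sample("obs", dist.Normal(0.0, scale[z]), obs=data)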