A problem encountered with OneHotCategorical and TraceEnum_ELBO

This is my VAE model and guide, with some unnecessary details omitted for simplicity. The model first samples “ys” from a uniform OneHotCategorical; “ys” has shape (batch_size, 3). Then “ys” is passed through an embedding layer to obtain “ys_emb” with shape (batch_size, 128). “ys_emb” is passed through an MLP to obtain the probabilities of “hs”, from which “hs” is sampled with OneHotCategorical. “hs_emb” is then obtained from “hs” in the same way as “ys_emb”. Finally, “hs_emb” is passed through an MLP to obtain the probabilities of “xs”, from which “xs” is sampled with OneHotCategorical. The one unusual part is that the probability tensor of “xs” (i.e. xs_prob, after reshaping) is three-dimensional, with shape (batch_size, 25, 5). That is, every sample consists of 25 independent categorical random variables, each with 5 possible outcomes.

class my_model(nn.Module):
    ...

    @config_enumerate(default="parallel")
    def model(self, xs, ys=None, hs=None):
        
        batch_size = xs.size(0)
         
        with pyro.plate("data"):
            alpha_prior = torch.ones(batch_size, self.y_size)/(1.0*self.y_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior).to_event(1), obs=ys)
            ys_emb = torch.mm(ys, self.emb_y)
            
            hs_prob = self.decoder_h.forward(ys_emb)
            hs_prob = hs_prob.exp()/hs_prob.exp().sum(axis=1, keepdims=True)
            hs = pyro.sample("h", dist.OneHotCategorical(hs_prob).to_event(1), obs=hs)
            hs_emb = torch.mm(hs, self.emb_h)
           
            xs_prob = self.decoder_x.forward([hs_emb])
            xs_prob = xs_prob.exp() / torch.mm(xs_prob.exp(), self.operator_x) # softmax with every 5 numbers
            if xs is not None:
                xs = pyro.sample("xs", dist.OneHotCategorical(xs_prob.reshape(-1,25,5)).to_event(2), obs=xs.reshape(-1,25,5))
            else:
                xs = pyro.sample("xs", dist.OneHotCategorical(xs_prob.reshape(-1,25,5)).to_event(2), obs=None)
        
    @config_enumerate(default="parallel")
    def guide(self, xs, ys=None, hs=None):
        with pyro.plate("data"):
            if hs is None:
                hs_prob = self.encoder_h.forward(xs)
                hs_prob = hs_prob.exp() / hs_prob.exp().sum(axis=1, keepdims=True)
                hs = pyro.sample("h", dist.OneHotCategorical(hs_prob).to_event(1))
 
            if ys is None:
                hs_emb = self.emb_h.forward(hs)
                alpha = self.encoder_y.forward([xs,hs_emb])
                alpha = alpha.exp() / alpha.exp().sum(axis=1,keepdims=True)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha).to_event(1)) 
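
For reference, the "softmax with every 5 numbers" step in the model can also be written with a reshape instead of an operator matrix. The sketch below uses plain PyTorch only. Here `logits` is a random stand-in for the output of decoder_x, and `operator_x` is reconstructed as a block-diagonal matrix of ones, which is an assumption about what self.operator_x contains (its definition is omitted above).

```python
import torch

torch.manual_seed(0)
batch_size, n_vars, n_cats = 4, 25, 5

# Stand-in for the raw output of decoder_x: shape (batch_size, 125).
logits = torch.randn(batch_size, n_vars * n_cats)

# Assumed form of operator_x: block-diagonal ones, so that
# exp(logits) @ operator_x gives every entry the sum over its own group of 5.
operator_x = torch.block_diag(*[torch.ones(n_cats, n_cats) for _ in range(n_vars)])
probs_mm = logits.exp() / torch.mm(logits.exp(), operator_x)

# Alternative: reshape to (batch_size, 25, 5) and softmax over the last dimension.
probs_softmax = torch.softmax(logits.reshape(-1, n_vars, n_cats), dim=-1)

# Compare the two formulations and check that each group of 5 sums to 1.
same = torch.allclose(probs_mm.reshape(-1, n_vars, n_cats), probs_softmax, atol=1e-6)
group_sums = probs_softmax.sum(-1)
print(same, group_sums.shape)
```

If the assumption about operator_x holds, the two results should agree, and the softmax form avoids overflow in exp() for large logits.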

I used the following lower-level training loop with TraceEnum_ELBO in order to add a penalty term to the objective. Note that “hs” is treated as a latent variable.

model = my_model(...)
guide = config_enumerate(model.guide, expand=True)
loss_fn = lambda model, guide: TraceEnum_ELBO().differentiable_loss(model,guide,xs,ys,None)
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model.model,guide)
params = set(site["value"].unconstrained() for site in param_capture.trace.nodes.values())
optimizer_model = torch.optim.Adam(params)
for i in range(num_iters):
    loss = loss_fn(model.model,guide) + penalty
    loss.backward()
    optimizer_model.step()
    optimizer_model.zero_grad()

However, I encountered the following warning:
"TraceEnum_ELBO found no sample sites configured for enumeration. "

I think this means that enumeration is not being performed as expected (so the loss is probably equivalent to Trace_ELBO?). Following the suggestion I found in the topic “GMM example in documentation gives warning about from TraceEnum_ELBO”, I added infer={"enumerate": "parallel"} to the sample statement of “hs” in both the model and the guide, and got the following error:

“NotImplementedError: Enumeration over cartesian product is not implemented”
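
The same message can be produced outside the model with plain torch.distributions, which may help narrow things down. The sketch below assumes that .to_event(1) in Pyro behaves like wrapping the distribution in torch's Independent with one reinterpreted batch dimension; the sizes are made up for illustration. This is only a guess at where the message comes from, not a confirmed diagnosis of the model above.

```python
import torch
from torch.distributions import Independent, OneHotCategorical

batch_size, h_size = 4, 3  # illustrative sizes, not the real ones
probs = torch.full((batch_size, h_size), 1.0 / h_size)

# Plain OneHotCategorical: the one-hot dimension is already the event dimension.
base = OneHotCategorical(probs=probs)
print(base.batch_shape, base.event_shape, base.has_enumerate_support)

# Assumed equivalent of .to_event(1): the batch dimension is folded into the event.
wrapped = Independent(base, 1)
print(wrapped.batch_shape, wrapped.event_shape, wrapped.has_enumerate_support)

# Asking the wrapped distribution for its support raises NotImplementedError.
try:
    wrapped.enumerate_support()
    error_message = None
except NotImplementedError as exc:
    error_message = str(exc)
print(error_message)
```

In this sketch the wrapped distribution reports that it has no enumerable support, which would also be consistent with the earlier warning that no sample sites were configured for enumeration.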

Does anybody know how to solve this problem? Sorry that the question is a bit long. Thank you so much for your time!

Woody