This is my VAE model and guide, with some unnecessary details omitted for simplicity. The model first samples `ys` from a uniform OneHotCategorical; the shape of `ys` is (batch_size, 3). `ys` is then passed through an embedding layer to obtain `ys_emb` with shape (batch_size, 128). `ys_emb` is passed through an MLP to obtain the probabilities for `hs`, from which `hs` is sampled with OneHotCategorical. `hs_emb` is then obtained from `hs` in the same way as `ys_emb`. Finally, `hs_emb` is passed through an MLP to obtain the probabilities for `xs`, from which `xs` is sampled with OneHotCategorical. The slightly unusual part is that the probability tensor for `xs` (i.e. `xs_prob`) is three-dimensional, with shape (batch_size, 25, 5). That is, each independent sample consists of 25 independent categorical random variables, each with 5 possible outcomes.

```
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate


class my_model(nn.Module):
    ...

    @config_enumerate(default="parallel")
    def model(self, xs, ys=None, hs=None):
        batch_size = xs.size(0)
        with pyro.plate("data"):
            # Uniform prior over the y_size classes
            alpha_prior = torch.ones(batch_size, self.y_size) / (1.0 * self.y_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior).to_event(1), obs=ys)
            ys_emb = torch.mm(ys, self.emb_y)
            hs_prob = self.decoder_h.forward(ys_emb)
            # Softmax over the classes of h
            hs_prob = hs_prob.exp() / hs_prob.exp().sum(axis=1, keepdims=True)
            hs = pyro.sample("h", dist.OneHotCategorical(hs_prob).to_event(1), obs=hs)
            hs_emb = torch.mm(hs, self.emb_h)
            xs_prob = self.decoder_x.forward([hs_emb])
            xs_prob = xs_prob.exp() / torch.mm(xs_prob.exp(), self.operator_x)  # softmax with every 5 numbers
            if xs is not None:
                xs = pyro.sample("xs", dist.OneHotCategorical(xs_prob.reshape(-1, 25, 5)).to_event(2), obs=xs.reshape(-1, 25, 5))
            else:
                xs = pyro.sample("xs", dist.OneHotCategorical(xs_prob.reshape(-1, 25, 5)).to_event(2), obs=None)

    @config_enumerate(default="parallel")
    def guide(self, xs, ys=None, hs=None):
        with pyro.plate("data"):
            # Sample h only when it is not observed
            if hs is None:
                hs_prob = self.encoder_h.forward(xs)
                hs_prob = hs_prob.exp() / hs_prob.exp().sum(axis=1, keepdims=True)
                hs = pyro.sample("h", dist.OneHotCategorical(hs_prob).to_event(1))
            # Sample y only when it is not observed
            if ys is None:
                hs_emb = self.emb_h.forward(hs)
                alpha = self.encoder_y.forward([xs, hs_emb])
                alpha = alpha.exp() / alpha.exp().sum(axis=1, keepdims=True)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha).to_event(1))
```
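
As a side note, the "softmax with every 5 numbers" step can also be sketched as a reshape followed by a softmax over the last dimension. This is an alternative formulation for illustration, not the `operator_x` approach used in the code above; the helper name `grouped_softmax` and the sizes are hypothetical.

```python
import torch

# Illustrative sizes (assumptions): 4 samples, 25 variables, 5 outcomes each
batch_size, n_vars, n_outcomes = 4, 25, 5


def grouped_softmax(logits, n_vars, n_outcomes):
    """Normalize each consecutive group of n_outcomes logits separately."""
    grouped = logits.reshape(-1, n_vars, n_outcomes)
    return torch.softmax(grouped, dim=-1)


torch.manual_seed(0)
logits = torch.randn(batch_size, n_vars * n_outcomes)  # stand-in for the decoder output
xs_prob = grouped_softmax(logits, n_vars, n_outcomes)

# Explicit exp / sum over each group of 5, for comparison
manual = logits.exp().reshape(-1, n_vars, n_outcomes)
manual = manual / manual.sum(dim=-1, keepdim=True)

print(xs_prob.shape)
print(torch.allclose(xs_prob, manual, atol=1e-6))
```

Compared with the explicit `exp()` / sum form, `torch.softmax` is usually the more numerically stable choice when the logits are large.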

I used the following lower-level pattern, based on `differentiable_loss` of TraceEnum_ELBO, in order to add a penalty term to the objective. Note that `hs` is treated as a latent variable (it is passed as `None`).

```
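import torch
import pyro
from pyro.infer import TraceEnum_ELBO, config_enumerate

model = my_model(...)
guide = config_enumerate(model.guide, expand=True)
# xs, ys: observed data; the last argument (hs) is None, so "h" is latent
loss_fn = lambda model, guide: TraceEnum_ELBO().differentiable_loss(model, guide, xs, ys, None)
# Run the loss once under a trace to collect the parameters to optimize
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model.model, guide)
params = set(site["value"].unconstrained() for site in param_capture.trace.nodes.values())
optimizer_model = torch.optim.Adam(params)
for i in range(num_iters):
    # penalty: the extra term added to the objective (its definition is omitted here)
    loss = loss_fn(model.model, guide) + penalty
    loss.backward()
    optimizer_model.step()
    optimizer_model.zero_grad()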
```

However, I encountered the following warning:

"TraceEnum_ELBO found no sample sites configured for enumeration. "

I think this means that enumeration is not being performed as expected (so the result is probably equivalent to Trace_ELBO?). Following the suggestion I found in the topic "GMM example in documentation gives warning about from TraceEnum_ELBO", I added `infer={"enumerate": "parallel"}` to the sample statement of `hs` in both the model and the guide, and got the following error:

“NotImplementedError: Enumeration over cartesian product is not implemented”
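
Since the error mentions a cartesian product, one thing that may be worth inspecting is how `.to_event(...)` changes the batch and event shapes of a OneHotCategorical. The sketch below only prints shapes and does not reproduce the error. It uses `torch.distributions.Independent` as a stand-in for Pyro's `.to_event` (an assumption for illustration), and the sizes are made up.

```python
import torch
from torch.distributions import Independent, OneHotCategorical

# Illustrative sizes (assumptions): 4 samples, 25 variables, 5 outcomes each
batch_size, n_vars, n_outcomes = 4, 25, 5

probs = torch.full((batch_size, n_vars, n_outcomes), 1.0 / n_outcomes)

# The one-hot dimension is already an event dimension of the base distribution
base = OneHotCategorical(probs=probs)
print(base.batch_shape, base.event_shape)  # batch (4, 25), event (5,)

# Reinterpreting 1 batch dimension as event: the 25 variables become one event
one = Independent(base, 1)
print(one.batch_shape, one.event_shape)  # batch (4,), event (25, 5)

# Reinterpreting 2 batch dimensions as event: the sample dimension is absorbed too
two = Independent(base, 2)
print(two.batch_shape, two.event_shape)  # batch (), event (4, 25, 5)

# log_prob returns one value per remaining batch element
x = base.sample()
print(one.log_prob(x).shape, two.log_prob(x).shape)
```

Whether these shapes are related to the error above is not established here; the sketch is only meant as a quick way to check which dimensions are treated as batch and which as event.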

Does anybody know how to solve this problem? Sorry that the question is a bit long. Thank you so much for your time!

Woody