I think you’ll need to change `infer={"enumerate": "sequential"}` to `infer={"enumerate": "parallel"}` in your model. You might find the Gaussian Mixture Model tutorial in the Pyro docs (Pyro Tutorials 1.8.6) useful for this example. Quoting from the tutorial:
> When enumerating variables in the model, the variables must be enumerated in parallel and must not appear in the guide. Mathematically, guide-side enumeration simply reduces variance in a stochastic ELBO by enumerating all values, whereas model-side enumeration avoids an application of Jensen’s inequality by exactly marginalizing out a variable.
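To make the Jensen point concrete, here is a small pure-Python sketch for a two-dice mixture like the one in this thread. The face probabilities in `f` are made-up values, and `exact_log_marginal` / `jensen_bound` are hypothetical helper names used only for illustration. Summing out `choose` in the model gives the exact log p(x); an ELBO with a guide q over `choose`, even with its expectation computed exactly (as guide-side enumeration would), is only a lower bound, and it is tight only when q equals the true posterior.

```python
import math

# Made-up face probabilities for two six-sided dice (illustrative values only).
f = [
    [1 / 6] * 6,                     # die 0: fair
    [0.5, 0.1, 0.1, 0.1, 0.1, 0.1],  # die 1: biased toward face 0
]
p_choose = [0.5, 0.5]  # Bernoulli(0.5) prior over which die was rolled


def exact_log_marginal(x):
    """Model-side enumeration: sum out `choose` exactly, log sum_z p(z) p(x | z)."""
    return math.log(sum(p_z * f[z][x] for z, p_z in enumerate(p_choose)))


def jensen_bound(x, q):
    """ELBO for a guide q(z) over `choose`, with the expectation over z computed exactly:
    sum_z q(z) [log p(z) + log p(x | z) - log q(z)] <= log p(x)."""
    return sum(
        q_z * (math.log(p_choose[z]) + math.log(f[z][x]) - math.log(q_z))
        for z, q_z in enumerate(q)
        if q_z > 0
    )


x = 0
print(exact_log_marginal(x))          # ≈ -1.0986, i.e. log(1/3)
print(jensen_bound(x, [0.5, 0.5]))    # ≈ -1.2425, strictly below the exact value
print(jensen_bound(x, [0.25, 0.75]))  # ≈ -1.0986: q is the exact posterior, so the bound is tight
```

With a uniform guide the bound falls short of log(1/3); only the exact posterior (1/4, 3/4) closes the gap, which is exactly the approximation that model-side enumeration avoids.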
I don’t think the shapes in your current model would still line up under parallel enumeration, but something like the following should work (not tested for correctness). It defines a 2×6 `bias` tensor, one face-probability vector per die, and indexes into it directly with `choose`:
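```python
import torch
import pyro
import pyro.distributions as dist

def model(data, n_faces=6, batch_size=None):
    with pyro.plate("select_dice", 2):  # one bias vector per die, so f has shape (2, n_faces)
        f = pyro.sample("bias", dist.Dirichlet(torch.ones(n_faces) * 10))
    with pyro.plate("z_minibatch", len(data), batch_size) as ind:
        choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}).long()
        pyro.sample("obs", dist.Categorical(f[choose]), obs=data[ind])
```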
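To see why indexing `f[choose]` works in both modes, here is a sketch that mimics by hand what parallel enumeration does to this model, using plain `torch.distributions` rather than Pyro. It assumes `max_plate_nesting=1`, so the enumerated `choose` gets shape `(2, 1)` (the enumeration dim sits to the left of the minibatch plate dim); `f`, `obs`, and the batch size of 50 are made-up stand-ins. The final `logsumexp` over the enumeration dim is, roughly, the exact marginalization that Pyro's `TraceEnum_ELBO` carries out for enumerated sites.

```python
import torch
import torch.distributions as tdist

torch.manual_seed(0)
n_faces, batch_size = 6, 50

# Stand-in for the sampled "bias" site: one probability vector per die, shape (2, n_faces).
f = tdist.Dirichlet(torch.ones(n_faces) * 10).sample((2,))
# Hypothetical minibatch of observed rolls (face indices 0..5), shape (batch_size,).
obs = torch.randint(0, n_faces, (batch_size,))

# Parallel enumeration: both values of "choose" live in a new dim left of the plate dim.
choose = torch.arange(2).reshape(2, 1)                       # (2, 1)
probs = f[choose]                                            # (2, 1, n_faces)
log_lik = tdist.Categorical(probs).log_prob(obs)             # broadcasts to (2, batch_size)
log_prior = tdist.Bernoulli(torch.tensor(0.5)).log_prob(choose.float())  # (2, 1)

# Sum out "choose" exactly: log p(x_i) = logsumexp_z [log p(z) + log p(x_i | z)].
log_marginal = torch.logsumexp(log_prior + log_lik, dim=0)   # (batch_size,)

# Without enumeration, one sampled choice per data point indexes f the same way.
choose_sampled = tdist.Bernoulli(0.5 * torch.ones(batch_size)).sample().long()
probs_sampled = f[choose_sampled]                            # (batch_size, n_faces)
```

So the same `f[choose]` line yields `(2, 1, 6)` when enumerated and `(batch_size, 6)` when sampled, and in both cases `Categorical` broadcasts cleanly against the observed minibatch.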