I have essentially the same question as was asked here. However, that thread didn’t end up answering the question. I am implementing a variant of a mixture model, where each latent discrete variable has multiple observations. This leads me to use nested plates, like so:
```python
import torch
import pyro
import pyro.distributions as dist

with pyro.plate('groups', N_groups):
    z = pyro.sample('z', dist.Categorical(theta))
    with pyro.plate('samples', N_samples_per_group):
        x = pyro.sample('x', dist.MultivariateNormal(mus[z],
                        covariance_matrix=torch.eye(latent_dim)))
```
In practice, however, the number of continuous observations (x) differs from one discrete variable (z) to the next, so a single `N_samples_per_group` doesn't fit. I could loop over the groups in Python, but that would be very slow. Are there any alternatives?