Mixture model with multidimensional weights prior

Hello,

I am trying to implement a mixture model in Pyro in which the weights of the categorical distribution have more than one dimension. This is because I have some prior knowledge about the categories of some (but not all) of my observations.

In this example, my data has 80 observations, each spanning two further dimensions of size 10 and 2, so:

x_data.shape = (80,10,2)

Furthermore, there are 4 categories, so in the simple case:

self.weights.shape = torch.Size([1, 1, 1, 4])

and in the more complex case with separate weights for each observation:

self.weights.shape = torch.Size([80, 1, 1, 4])
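To make the per-observation case concrete, here is a minimal sketch of one way such an (80, 1, 1, 4) weights tensor could be built from partial labels. The labelled indices, their categories, and the 0.9 confidence value are made up for illustration; unlabelled observations keep a uniform prior. A categorical distribution treats the last dimension of `probs` as the categories and the remaining dimensions as batch dimensions, so this gives a batch shape of (80, 1, 1).

```python
import torch

n_obs, n_categories = 80, 4

# Hypothetical prior knowledge: observations 0-9 are known to be category 2
# and observations 10-19 category 0; the remaining 60 are unlabelled.
known_idx = torch.arange(20)
known_cat = torch.cat([torch.full((10,), 2, dtype=torch.long),
                       torch.full((10,), 0, dtype=torch.long)])

# Start from a uniform prior over the 4 categories for every observation.
weights = torch.full((n_obs, n_categories), 1.0 / n_categories)

# For labelled observations, put most of the mass (0.9, an arbitrary choice)
# on the known category and spread the rest evenly over the other three.
confidence = 0.9
weights[known_idx] = (1.0 - confidence) / (n_categories - 1)
weights[known_idx, known_cat] = confidence

# Insert two singleton dims so the batch shape (80, 1, 1) lines up with
# the (80, 10, 2) layout of x_data.
weights = weights.reshape(n_obs, 1, 1, n_categories)
print(weights.shape)  # torch.Size([80, 1, 1, 4])
```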

My likelihood function is:

with obs_plate:
    assignment = pyro.sample('assignment', dist.Categorical(probs=self.weights))
    rate = alpha / mu[torch.arange(self.n_obs), :, :, assignment[:, 0, 0, 0]]
    pyro.sample("data_target",
                dist.GammaPoisson(concentration=alpha, rate=rate), obs=x_data)
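One thing that may be worth checking is how the rate indexing behaves once `assignment` is enumerated. As I understand Pyro's parallel enumeration, the sampled value is replaced by a tensor holding all category indices along a new dimension to the left of the plate dimensions, so an expression like `assignment[:, 0, 0, 0]`, which assumes a particular layout of `assignment`, may not index what was intended. Below is a torch-only sketch under hypothetical assumptions: the per-category means are stored category-first as `mu` with shape (4, 10, 2), the observation plate is the only batch dimension (`max_plate_nesting=1`), and the enumerated value therefore has shape (4, 1).

```python
import torch

n_obs, n_features, n_conditions, n_categories = 80, 10, 2, 4

# Hypothetical parameters: per-category means stored category-first,
# and a per-feature concentration with the same (10, 1) shape as in the post.
mu = torch.rand(n_categories, n_features, n_conditions) + 0.5
alpha = torch.rand(n_features, 1) + 0.5

def gamma_poisson_rate(assignment):
    # Indexing the leading (category) dim with a long tensor of shape S
    # returns shape S + (10, 2), so the same line works for a sampled
    # assignment and for an enumerated one.
    return alpha / mu[assignment]

# Ordinary sampling: one category per observation, shape (80,).
sampled = torch.randint(n_categories, (n_obs,))
# Parallel enumeration (assumed max_plate_nesting=1): all categories on a
# new left-most dim, shape (4, 1).
enumerated = torch.arange(n_categories).reshape(n_categories, 1)

print(gamma_poisson_rate(sampled).shape)     # torch.Size([80, 10, 2])
print(gamma_poisson_rate(enumerated).shape)  # torch.Size([4, 1, 10, 2])
```

The enumerated rate of shape (4, 1, 10, 2) broadcasts against `x_data` of shape (80, 10, 2) to one likelihood term per category and observation. In a Pyro model one would then typically declare the two trailing data dimensions as event dimensions (for example with `.to_event(2)` on the likelihood) so that the observation plate is the only batch dimension.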

If I do not use separate weight priors for each observation, all variables seem to have the correct shapes when I train the model and print the shape of each variable:

self.weights.shape = torch.Size([1, 1, 1, 4])
assignment.shape = torch.Size([1, 1, 1, 4])
alpha.shape = torch.Size([10, 1])
rate.shape = torch.Size([80, 10, 2])
x_data.shape = torch.Size([80, 10, 2])

However, when I use separate weight priors for each observation, I think the shape of the assignment variable is wrong:

self.weights.shape = torch.Size([80, 1, 1, 4])
assignment.shape = torch.Size([80, 1, 1, 4])
alpha.shape = torch.Size([10, 1])
rate.shape = torch.Size([80, 10, 2])
x_data.shape = torch.Size([80, 10, 2])

The assignment variable is inferred via enumeration during training, so as I understand it, it should still have shape (1,1,1,4) rather than (80,1,1,4).
Memory requirements also increase drastically in this case, so I cannot run the model on a large dataset. I am stuck at this point, and it would be really great if someone could help me get this more complex model to work.
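For reference, here is a small torch-only check of the shapes involved. My understanding (which may be off) is that Pyro's parallel enumeration builds on `enumerate_support(expand=False)` of the site's distribution; the sketch uses `torch.distributions.Categorical` with uniform weights purely as a stand-in. It shows that the enumerated values have the same shape with shared or per-observation weights, while the log-probabilities of those values broadcast up to the observation dimension.

```python
import torch
from torch.distributions import Categorical

n_obs, n_categories = 80, 4

shared = Categorical(probs=torch.full((1, 1, 1, n_categories), 1.0 / n_categories))
per_obs = Categorical(probs=torch.full((n_obs, 1, 1, n_categories), 1.0 / n_categories))

# The batch shape is the probs shape minus the trailing category dim.
print(shared.batch_shape, per_obs.batch_shape)  # torch.Size([1, 1, 1]) torch.Size([80, 1, 1])

# Without expansion, the enumerated support depends only on the number of
# categories and the number of batch dims, not on the batch sizes.
print(shared.enumerate_support(expand=False).shape)   # torch.Size([4, 1, 1, 1])
print(per_obs.enumerate_support(expand=False).shape)  # torch.Size([4, 1, 1, 1])

# Log-probabilities of the enumerated values broadcast against the batch shape.
values = per_obs.enumerate_support(expand=False)
print(per_obs.log_prob(values).shape)  # torch.Size([4, 80, 1, 1])
```

So an enumerated value of shape (4, 1, 1, 1) would be expected in both cases, while tensors computed from it against per-observation parameters carry the 80-sized observation dimension.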

Best wishes,

Alexander