Problem implementing a mixture model using enumeration

Thank you for the quick reply! I realize that in writing the toy model above, I miscommunicated my problem. The size of plate1 should actually match the last dimension of variable A, which is a batch dimension, not an event dimension; both have size 5. I’ve corrected this in the original post.

With that, my original problem is still unsolved, as I don’t think I want to turn A into an event with to_event. I can get the toy model running by using .to_event(1) and then indexing into A later on, but I’m not sure this is good practice, or whether it will carry over to more complex models.

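import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

@config_enumerate
def full_model(data):
    # plate1 (size 5, dim=-1) is declared but not used in this to_event(1) version
    plate1 = pyro.plate("plate1", 5, dim=-1)
    plate2 = pyro.plate("plate2", data.shape[-2], dim=-2)
    plate3 = pyro.plate("plate3", data.shape[-3], dim=-3)

    with plate3:
        # A: batch shape (N3, 1, 1), event shape (5,), so A has shape (N3, 1, 1, 5)
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]),
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])).to_event(1))

        # mixture weights and the discrete assignment, which is enumerated in parallel
        probs = pyro.sample("probs", dist.Dirichlet(torch.tensor([30., 70.])))
        assignment = pyro.sample("assignment", dist.Categorical(probs), infer={"enumerate": "parallel"})

    with plate2:
        B = pyro.sample("B", dist.Uniform(0., 1.))  # shape (N2, 1)

    # A[:, :, 0, :] drops a size-1 dim: (N3, 1, 1, 5) -> (N3, 1, 5); times B broadcasts to
    # (N3, N2, 5), and summing over the 5 components gives (N3, N2, 1).
    # Under enumeration, torch.where adds the enumeration dim on the left: (2, N3, N2, 1).
    prediction = torch.where(assignment == 1,
                             (A[:, :, 0, :] * B).sum(-1).unsqueeze(-1),
                             (A[:, :, 0, :] * -B).sum(-1).unsqueeze(-1))

    with plate3, plate2:
        X = pyro.sample("X", dist.Gamma(torch.exp(prediction),
                                        torch.tensor(1.)), obs=data)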

Thank you for your help and patience!