Shape mismatch with multiple indexing steps

I have a model whose shapes I don’t quite understand. I’ve read the tutorial on shapes, but the right way to handle dimensions for my model continues to elude me.

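import torch
from torch import ones, zeros
import pyro.distributions as dist
from pyro import param, plate, sample
from pyro.infer import config_enumerate

num_variants = 3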

@config_enumerate
def model(N=10):
    
    ps = sample('ps', dist.Uniform(zeros(num_variants), ones(num_variants)).to_event()) # torch.Size([3]) 
    labels = sample('labels', dist.Bernoulli(ps).to_event()) # torch.Size([3]) 
    penetrances = param('penetrance', torch.tensor([.5, .5])) # torch.Size([2])
    
    with plate('individuals', N, dim=-1):
        variant = sample('variant', dist.Categorical(ones(num_variants)), 
                            infer={'enumerate': 'parallel'}).long() # torch.Size([10]) 
        label = labels[variant] # torch.Size([10])
        penetrance = penetrances[label.long()] # torch.Size([10])
        affected = sample('affected', dist.Bernoulli(penetrance)) # torch.Size([10])
    
    # print shapes
    print(f"{ps.shape=}")
    print(f"{labels.shape=}")
    print(f"{penetrances.shape=}")
    print(f"{variant.shape=}")
    print(f"{label.shape=}")
    print(f"{penetrance.shape=}")
    print(f"{affected.shape=}")

The idea is that there can be multiple individuals (samples) with a given variant. I want to do inference about ps, the Bernoulli probability behind the variant label.

When I sample the model like this:

with plate('samples', 5, dim=-2):
    model()

I see

ValueError: Shape mismatch inside plate('individuals') at site affected dim -1, 1 vs 3
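
For reference, a minimal torch-only sketch of how a dimension of size 3 can end up at dim -1, which may be what the error is reporting. The batched shapes below (an outer batch of size S = 5 and labels shaped (S, 1, 3)) are assumptions for illustration, not values taken from the failing run. Plain tensor indexing such as labels[variant] selects along the first dimension and keeps any trailing dimensions, so once labels has leading batch dimensions the variant dimension is left at the end.

```python
import torch

num_variants, N, S = 3, 10, 5  # S: size of the outer 'samples' plate (assumed)

# No outer plate: labels is (3,), variant is (10,), so the lookup is (10,).
labels = torch.zeros(num_variants)
variant = torch.randint(num_variants, (N,))
unbatched = labels[variant]

# With leading batch dims (shapes assumed for illustration), plain indexing
# selects along dim 0 and keeps the trailing variant dim of size 3.
labels_b = torch.zeros(S, 1, num_variants)
variant_b = torch.randint(num_variants, (S, N))
naive = labels_b[variant_b]

# Selecting along the last dim instead: one lookup per (sample, individual).
picked = (
    labels_b.expand(S, N, num_variants)
    .gather(-1, variant_b.unsqueeze(-1))
    .squeeze(-1)
)

print(unbatched.shape)  # torch.Size([10])
print(naive.shape)      # torch.Size([5, 10, 1, 3])
print(picked.shape)     # torch.Size([5, 10])
```

The gather call is only one way to select along the last dimension; it is shown here because it needs nothing beyond torch.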

Changing the plate dim in the model to -2 and the sampling plate to -3 prevents the ValueError, but the sites all have a ton of singleton dimensions, and multiple dimensions equal to the sample size instead of just one:

ps.shape=torch.Size([5, 1, 1, 3])
labels.shape=torch.Size([5, 1, 1, 5, 1, 1, 3])
penetrances.shape=torch.Size([2])
variant.shape=torch.Size([5, 1, 1])
label.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
penetrance.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
affected.shape=torch.Size([5, 1, 1, 1, 1, 5, 5, 1, 3])

What should I change to fix the ValueError without creating all these extra duplicate dimensions?

Have you also read the enumeration tutorial?

I have. I can give it a reread, though.

edit: I’ve also tried with Vindex on the lines that start with label and penetrance, as recommended in the tutorials, but that didn’t change the shapes at all.