# Shape mismatch with multiple indexing steps

I have a model whose shapes I don’t quite understand. I’ve read the tutorial on shapes, but the right way to handle dimensions for my model continues to elude me.

```python
import torch
import pyro.distributions as dist
from pyro import sample, param, plate
from pyro.infer import config_enumerate
from torch import zeros, ones

num_variants = 3

@config_enumerate
def model(N=10):
    ps = sample('ps', dist.Uniform(zeros(num_variants), ones(num_variants)).to_event())  # torch.Size([3])
    labels = sample('labels', dist.Bernoulli(ps).to_event())  # torch.Size([3])
    penetrances = param('penetrance', torch.tensor([.5, .5]))  # torch.Size([2])

    with plate('individuals', N, dim=-1):
        variant = sample('variant', dist.Categorical(ones(num_variants)),
                         infer={'enumerate': 'parallel'}).long()  # torch.Size([10])
        label = labels[variant]  # torch.Size([10])
        penetrance = penetrances[label.long()]  # torch.Size([10])
        affected = sample('affected', dist.Bernoulli(penetrance))  # torch.Size([10])

    # print shapes
    print(f"{ps.shape=}")
    print(f"{labels.shape=}")
    print(f"{penetrances.shape=}")
    print(f"{variant.shape=}")
    print(f"{label.shape=}")
    print(f"{penetrance.shape=}")
    print(f"{affected.shape=}")
```

The idea is that there can be multiple individuals (samples) with a given variant. I want to do inference about `ps`, the Bernoulli probability behind the variant label.

When I sample the model like this:

```python
with plate('samples', 5, dim=-2):
    model()
```

I see

```
ValueError: Shape mismatch inside plate('individuals') at site affected dim -1, 1 vs 3
```
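A plain-torch sketch of where the shapes may be going wrong. It assumes that under `plate('samples', 5, dim=-2)`, with no enumeration active, `labels` gets shape `(5, 1, 3)` and `variant` gets shape `(5, 10)`; the tensors hold random stand-in values rather than draws from the model:

```python
import torch

num_variants, num_samples, N = 3, 5, 10

# Assumed shapes under plate('samples', 5, dim=-2), no enumeration:
# the outer plate puts a (5, 1) batch shape in front of the variant dim.
labels = torch.bernoulli(torch.full((num_samples, 1, num_variants), 0.5))
# `variant` is broadcast across both plates: (samples, individuals).
variant = torch.randint(0, num_variants, (num_samples, N))

# Integer-tensor indexing selects along the LEFTMOST dim of `labels`
# (the samples dim), not the variant dim, and keeps the remaining dims.
label = labels[variant]
print(label.shape)  # torch.Size([5, 10, 1, 3])

penetrances = torch.tensor([0.5, 0.5])
penetrance = penetrances[label.long()]
print(penetrance.shape)  # torch.Size([5, 10, 1, 3])
```

Because the index is applied to the leftmost dim, the rightmost dim of `penetrance` ends up as `num_variants` rather than the `individuals` dim, which looks consistent with the `... vs 3` in the error.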

Changing the plate dim in the model to -2 and the sampling plate to -3 avoids the ValueError, but then every site picks up a ton of singleton dimensions, plus several dimensions of size 5 (the sample size) instead of just one:

```
ps.shape=torch.Size([5, 1, 1, 3])
labels.shape=torch.Size([5, 1, 1, 5, 1, 1, 3])
penetrances.shape=torch.Size([2])
variant.shape=torch.Size([5, 1, 1])
label.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
penetrance.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
affected.shape=torch.Size([5, 1, 1, 1, 1, 5, 5, 1, 3])
```

What should I change to fix the ValueError without creating all these extra duplicate dimensions?

Have you also read the enumeration tutorial?

I have. I can give it a reread, though.

Edit: I’ve also tried `Vindex` on the lines that start with `label` and `penetrance`, as the tutorials recommend, but that didn’t change the shapes at all.
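For comparison, here is a plain-torch sketch of the kind of indexing the tutorials appear to be after: pick along the *rightmost* (variant) dim of `labels` while broadcasting its batch dims against the index. It assumes Pyro's `Vindex(labels)[..., variant]` follows these broadcasting rules for a single index on the last dim; `index_last_dim` is a hypothetical helper, not a Pyro function, and the `(3, 1, 1)` shape for an enumerated `variant` is likewise an assumption about how parallel enumeration would lay it out with two plate dims to its right:

```python
import torch

def index_last_dim(x, idx):
    """Hypothetical helper: x[..., idx] elementwise, broadcasting the batch
    dims of x (all but the last) against the shape of idx."""
    batch_shape = torch.broadcast_shapes(x.shape[:-1], idx.shape)
    x = x.expand(batch_shape + x.shape[-1:])
    idx = idx.expand(batch_shape)
    # gather along the last dim, then drop the singleton dim gather needs
    return x.gather(-1, idx.unsqueeze(-1)).squeeze(-1)

labels = torch.bernoulli(torch.full((5, 1, 3), 0.5))  # (samples, 1, variants)
penetrances = torch.tensor([0.5, 0.5])

# Sampled variant, shape (samples, individuals)
variant = torch.randint(0, 3, (5, 10))
label = index_last_dim(labels, variant)
penetrance = index_last_dim(penetrances, label.long())
print(label.shape, penetrance.shape)  # torch.Size([5, 10]) torch.Size([5, 10])

# Enumerated variant (assumed layout: enum dim left of both plate dims)
variant_enum = torch.arange(3).reshape(3, 1, 1)
label_enum = index_last_dim(labels, variant_enum)
print(label_enum.shape)  # torch.Size([3, 5, 1])
```

The leading `...` is what makes the index apply to the variant dim instead of the samples dim. Note that `penetrances` has no batch dims, so ordinary indexing `penetrances[label.long()]` already returns `label.shape`; the helper is used there only to show it reduces to the same thing.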