I have a model whose shapes I don’t quite understand. I’ve read the tutorial on shapes, but the right way to handle dimensions for my model continues to elude me.
import torch
from torch import ones, zeros
import pyro.distributions as dist
from pyro import param, plate, sample
from pyro.infer import config_enumerate

num_variants = 3

@config_enumerate
def model(N=10):
    ps = sample('ps', dist.Uniform(zeros(num_variants), ones(num_variants)).to_event())  # torch.Size([3])
    labels = sample('labels', dist.Bernoulli(ps).to_event())  # torch.Size([3])
    penetrances = param('penetrance', torch.tensor([.5, .5]))  # torch.Size([2])
    with plate('individuals', N, dim=-1):
        variant = sample('variant', dist.Categorical(ones(num_variants)),
                         infer={'enumerate': 'parallel'}).long()  # torch.Size([10])
        label = labels[variant]  # torch.Size([10])
        penetrance = penetrances[label.long()]  # torch.Size([10])
        affected = sample('affected', dist.Bernoulli(penetrance))  # torch.Size([10])
    # print shapes
    print(f"{ps.shape=}")
    print(f"{labels.shape=}")
    print(f"{penetrances.shape=}")
    print(f"{variant.shape=}")
    print(f"{label.shape=}")
    print(f"{penetrance.shape=}")
    print(f"{affected.shape=}")
The idea is that there can be multiple individuals (samples) with a given variant. I want to do inference about ps, the Bernoulli probability behind the variant label.
When I sample the model like this:
with plate('samples', 5, dim=-2):
    model()
I see
ValueError: Shape mismatch inside plate('individuals') at site affected dim -1, 1 vs 3
Changing the plate dim in the model to -2 and the sampling plate to -3 prevents the ValueError, but the sites all have a ton of singleton dimensions, and several dimensions equal to the sample size instead of just one:
ps.shape=torch.Size([5, 1, 1, 3])
labels.shape=torch.Size([5, 1, 1, 5, 1, 1, 3])
penetrances.shape=torch.Size([2])
variant.shape=torch.Size([5, 1, 1])
label.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
penetrance.shape=torch.Size([5, 1, 1, 1, 1, 5, 1, 1, 3])
affected.shape=torch.Size([5, 1, 1, 1, 1, 5, 5, 1, 3])
What should I change to fix the ValueError without creating all these extra duplicate dimensions?
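To narrow down where the extra dimensions come from, here is a rough sketch of the shape bookkeeping in plain Python. It is not Pyro code: the helpers `to_event`, `sample_shape`, and `index_first_dim` are hypothetical and only mimic, under my simplified understanding, how a plate broadcasts batch dims and how `.to_event()` with no argument turns every batch dim into an event dim. It reproduces the `ps`, `labels`, and `label` shapes listed above for the dim=-2 / dim=-3 setup.

```python
# Simplified shape bookkeeping in plain Python. These helpers are
# hypothetical stand-ins, not Pyro functions.

def to_event(batch_shape, event_shape, n=None):
    """Move the rightmost n batch dims into the event shape (all if n is None)."""
    if n is None:
        n = len(batch_shape)
    split = len(batch_shape) - n
    return tuple(batch_shape[:split]), tuple(batch_shape[split:]) + tuple(event_shape)

def sample_shape(batch_shape, event_shape, plates):
    """Shape of a sampled value: batch dims broadcast to the plates, then event dims.

    plates maps a negative batch dim to the plate size, e.g. {-3: 5}.
    """
    ndim = max([len(batch_shape)] + [-dim for dim in plates])
    batch = [1] * (ndim - len(batch_shape)) + list(batch_shape)
    for dim, size in plates.items():
        if batch[dim] not in (1, size):
            raise ValueError(f"shape mismatch at dim {dim}: {size} vs {batch[dim]}")
        batch[dim] = size
    return tuple(batch) + tuple(event_shape)

def index_first_dim(x_shape, index_shape):
    """Shape of x[index] for an integer index tensor: index dims replace dim 0."""
    return tuple(index_shape) + tuple(x_shape[1:])

plates = {-3: 5}  # the outer plate('samples', 5, dim=-3)

# ps: Uniform(zeros(3), ones(3)).to_event() has batch () and event (3,)
ps_shape = sample_shape(*to_event((3,), ()), plates)

# labels: Bernoulli(ps).to_event() makes every dim of ps an event dim,
# including the plate dim, and then the plate adds its batch dims again
labels_shape = sample_shape(*to_event(ps_shape, ()), plates)

# label = labels[variant], taking variant.shape == (5, 1, 1) as reported above
label_shape = index_first_dim(labels_shape, (5, 1, 1))

# Assumed alternative: to_event(1) keeps the plate dims as batch dims
labels_alt_shape = sample_shape(*to_event(ps_shape, (), 1), plates)

print(ps_shape)          # (5, 1, 1, 3)
print(labels_shape)      # (5, 1, 1, 5, 1, 1, 3)
print(label_shape)       # (5, 1, 1, 1, 1, 5, 1, 1, 3)
print(labels_alt_shape)  # (5, 1, 1, 3)
```

If this bookkeeping is right, the duplicated 5s seem to come from `.to_event()` absorbing the plate dimension, and `.to_event(1)` might avoid that. The indexing `labels[variant]` would then presumably also need to select along the last dimension while broadcasting over the batch dims (perhaps with `pyro.ops.indexing.Vindex`), but I have not verified this against the real model.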