Model and guide shapes disagree at site ‘z_2’: torch.Size([2, 2]) vs torch.Size([2])
Does anyone have a clue why the shapes disagree?
Here is the z_t sample site in the model. z_loc is a torch tensor with shape [2, 5]: 2 is the batch size and 5 is the number of possible values.
The z_loc in the model and the one in the guide are different.
z_t = pyro.sample("z_%d" % t, dist.Categorical(logits=z_loc))
and the one in the guide:
z_t = pyro.sample("z_%d" % t, dist.Categorical(logits=z_loc))
I printed out the samples drawn from each distribution, but I could not see any difference in their shapes.
guide z_1 tensor([2, 4])
guide z_2 tensor([4, 0])
guide z_3 tensor([1, 2])
model z_1 tensor([2, 4])
model z_2 tensor([4, 0])
model z_3 tensor([1, 2])
I am investigating whether the problem comes from my own distribution parameters, z_loc.
Update:
I found that the problem might be the design of my model, which uses nested plates:
with pyro.plate(...):
    # some code
    with pyro.plate(...):
        ...
I assigned dim=-2 to the first plate and dim=-1 to the second one.