Hi all, I’m back again.
Sorry to bother you!
The original code that samples xs in model() of ss_vae_M2.py is:
loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)
Here loc has shape [10, 200, 784] (10 from parallel enumeration, 200 from the original batch_size, and 784 pixels per 28×28 image), and dist.Bernoulli(...).to_event(1) correctly infers a batch_shape of [10, 200] and an event_shape of [784]:
b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")
which prints:
b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])
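To make the Bernoulli case concrete, here is a minimal plain-Python sketch of the shape bookkeeping, assuming the usual torch/Pyro rules: Bernoulli is elementwise, so the whole parameter shape becomes the batch shape, and .to_event(n) moves the rightmost n batch dims into the event shape. The helper names bernoulli_shapes and to_event are hypothetical stand-ins, not Pyro functions.

```python
# Hypothetical helpers that mimic, in plain Python, how Pyro/PyTorch
# distributions derive batch/event shapes from a parameter's shape.

def bernoulli_shapes(param_shape):
    # Bernoulli is elementwise: every entry of `probs` is its own scalar
    # distribution, so the full parameter shape is the batch shape.
    return tuple(param_shape), ()

def to_event(batch_shape, event_shape, n):
    # Mimics .to_event(n): the rightmost n batch dims become event dims.
    split = len(batch_shape) - n
    return tuple(batch_shape[:split]), tuple(batch_shape[split:]) + tuple(event_shape)

loc_shape = (10, 200, 784)  # enumeration x batch_size x pixels
batch, event = to_event(*bernoulli_shapes(loc_shape), 1)
print(batch, event)  # (10, 200) (784,)
```

This matches the printed torch.Size([10, 200]) / torch.Size([784]) above.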
However, when I replace Bernoulli with Categorical, leaving all other code and data unchanged:
loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")
the printed shapes become:
b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])
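A likely explanation, sketched with the same kind of hypothetical plain-Python helpers: Categorical reads the last dim of its probs/logits as the number of categories and emits a scalar class index per draw. So for loc of shape [10, 200, 784], the 784 is consumed as a category count, the batch shape is [10, 200] with an empty event shape, and .to_event(1) then moves 200 into the event. Giving the decoder a trailing class dim (256 grey levels is purely an assumption for illustration) would keep the pixel dim as a batch dim. In that case obs=xs would also need to hold integer class indices rather than float pixel intensities.

```python
# Hypothetical helpers mimicking torch/Pyro shape rules; not Pyro functions.

def categorical_shapes(param_shape):
    # Categorical treats the LAST dim of `probs`/`logits` as the category
    # count, so it is dropped from the batch shape; each draw is a scalar
    # class index, so the event shape is empty.
    return tuple(param_shape[:-1]), ()

def to_event(batch_shape, event_shape, n):
    # Mimics .to_event(n): the rightmost n batch dims become event dims.
    split = len(batch_shape) - n
    return tuple(batch_shape[:split]), tuple(batch_shape[split:]) + tuple(event_shape)

# Same decoder output as in the question: 784 is read as 784 categories.
print(to_event(*categorical_shapes((10, 200, 784)), 1))  # ((10,), (200,))

# Assumed fix: a trailing class dim (e.g. 256 grey levels) so the pixel
# dim survives as a batch dim and .to_event(1) turns it into the event.
print(to_event(*categorical_shapes((10, 200, 784, 256)), 1))  # ((10, 200), (784,))
```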
Could someone help me figure this out?