Question about batch_size in the semi-supervised VAE demo

Because the `loc` passed as `probs` has shape (200, 784), the resulting `Categorical` has batch shape (200,) and event shape () — the last dimension is interpreted as the number of categories. By analogy: a fair die has a `probs` vector of shape (6,), and each throw yields a single number, not six numbers. So I think your decoder needs to return something with shape (200, 784, dim), where dim is the total number of possible values of xs, e.g. dim = xs.max() + 1. Btw, an MLP typically outputs values on the real line; if that's the case here, it's better to pass `loc` as logits rather than probs.
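To make the shape semantics concrete, here is a small sketch using `torch.distributions.Categorical` (Pyro's `Categorical` follows the same conventions); the value `dim = 2` is just a hypothetical number of pixel values:

```python
import torch
from torch.distributions import Categorical

# Categorical treats the LAST dim of probs as the categories:
# probs of shape (200, 784) -> batch_shape (200,), event_shape ()
d = Categorical(probs=torch.full((200, 784), 1.0 / 784))
assert d.batch_shape == torch.Size([200])
assert d.event_shape == torch.Size([])
# One sampled category index per batch element, NOT one per pixel:
assert d.sample().shape == torch.Size([200])

# For a per-pixel Categorical over `dim` values, the parameter tensor
# must have shape (200, 784, dim):
dim = 2  # hypothetical number of possible pixel values
d2 = Categorical(logits=torch.zeros(200, 784, dim))  # logits: unnormalized reals
assert d2.batch_shape == torch.Size([200, 784])
assert d2.sample().shape == torch.Size([200, 784])
```

Note the second distribution takes `logits` directly, so a raw MLP output needs no extra sigmoid/softmax.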