dist.Bernoulli recognizes the right batch_size, but dist.Categorical doesn't

Thank you, fritzo. fehiepsi helped me figure this out in another topic: https://forum.pyro.ai/t/question-about-batch-size-in-the-semi-supervised-vae-demo/4891/9

Following your instructions, I changed the code as follows:

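# Decoder output: a flat vector of logits for each data point
loc = self.decoder_x(thetas)
# Reshape so that each instance gets its own row of num_categories logits
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
# to_event(1) declares the num_instances dimension as part of a single event
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)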

In this version I:

  1. expanded the output of self.decoder_x to size original_dim * num_categories,
  2. reshaped that long vector into a tensor of shape [-1, batch_size, num_instances, num_categories], and
  3. fed the tensor into dist.Categorical as logits.
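
To sanity-check the shapes these three steps produce, here is a minimal sketch in plain PyTorch. The sizes are made up, a random tensor stands in for the decoder output, and torch.distributions.Independent is used in place of Pyro's .to_event(1), which I am assuming behaves the same way for this purpose.

```python
import torch
from torch.distributions import Categorical, Independent

# Hypothetical sizes, chosen only for illustration.
batch_size, num_instances, num_categories = 4, 6, 3

# Stand-in for the decoder output: one flat vector of logits per data point.
flat_logits = torch.randn(batch_size, num_instances * num_categories)

# Step 2: split the flat vector into one row of category logits per instance.
logits = flat_logits.reshape(-1, batch_size, num_instances, num_categories)

# Step 3: the last dimension is consumed as the category axis; wrapping in
# Independent(..., 1) then moves num_instances from batch_shape to event_shape.
d = Independent(Categorical(logits=logits), 1)

# Observations are integer category labels, one per instance.
xs = torch.randint(num_categories, (batch_size, num_instances))
log_prob = d.log_prob(xs)

batch_shape = tuple(d.batch_shape)
event_shape = tuple(d.event_shape)
log_prob_shape = tuple(log_prob.shape)
print(batch_shape, event_shape, log_prob_shape)
```

If the batch dimension is recognized as intended, batch_shape should end in batch_size and event_shape should be (num_instances,), so log_prob has one entry per data point.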

May I ask whether this is a proper way to handle it?