Thank you, fritzo. fehiepsi helped me figure this out in another topic: https://forum.pyro.ai/t/question-about-batch-size-in-the-semi-supervised-vae-demo/4891/9
Following your instructions, I made the following changes to my code:
loc = self.decoder_x(thetas)
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)
Specifically, I:
- increased the output dimension of self.decoder_x to original_dim * num_categories,
- reshaped that long output vector into a tensor of shape [-1, batch_size, num_instances, num_categories], and
- fed that tensor into dist.Categorical as logits.
Could you help me check whether this is a proper way to handle this?