Question about batch_size in the semi-supervised VAE demo

Hi fehiepsi, could you please help me with the issue in my thread "dist.Bernoulli recognize the right batch_size, but dist.Categorical won't" (Tutorials, Pyro Discussion Forum)?

The same code with a Bernoulli distribution recognizes the right batch_shape and event_shape:

x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

outputs:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])
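A minimal sketch of why the Bernoulli shapes come out this way. It assumes `loc` has shape (10, 200, 784), which is inferred from the printed output, and it uses `torch.distributions` directly instead of Pyro. Pyro's `.to_event(1)` wraps a distribution the same way `Independent(base, 1)` does.

```python
import torch
from torch.distributions import Bernoulli, Independent

# Assumed shape of `loc`, inferred from the printed output: (10, 200, 784)
probs = torch.rand(10, 200, 784)

# Bernoulli has a scalar event, so every element of `probs` is its own batch element
base = Bernoulli(probs=probs)
print(base.batch_shape, base.event_shape)  # torch.Size([10, 200, 784]) torch.Size([])

# Independent(base, 1), like Pyro's .to_event(1), moves the rightmost
# batch dim (784) into the event shape
b = Independent(base, 1)
print(b.batch_shape, b.event_shape)  # torch.Size([10, 200]) torch.Size([784])
```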

but the same code with Categorical:

x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

output:

b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])
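The same sketch with Categorical, again assuming `loc` has shape (10, 200, 784) and using `torch.distributions` in place of Pyro. Unlike Bernoulli, `Categorical` reads the last dimension of `probs` as the category probabilities, so 784 is consumed as the number of categories. The batch shape is then only (10, 200), and `.to_event(1)` moves 200 (not 784) into the event shape, which matches the output above.

```python
import torch
from torch.distributions import Categorical, Independent

# Same assumed shape for `loc`: (10, 200, 784)
probs = torch.rand(10, 200, 784)

# Categorical treats the LAST dim of `probs` as the categories,
# so 784 becomes the number of categories, not a batch or event dim
base = Categorical(probs=probs)
print(base.batch_shape, base.event_shape)  # torch.Size([10, 200]) torch.Size([])
print(base.probs.shape[-1])  # 784 categories per distribution

# Independent(base, 1) (Pyro's .to_event(1)) now moves 200 into the event shape
b = Independent(base, 1)
print(b.batch_shape, b.event_shape)  # torch.Size([10]) torch.Size([200])
```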

I dug a little further into the torch.distributions.Categorical implementation; maybe you could help me out in that topic.
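Because `torch.distributions.Categorical` always takes the last dimension of `probs` as the categories, one way to get batch_shape (10, 200) and event_shape (784) is to give it an explicit trailing category dimension. This is a sketch under assumptions: it supposes each of the 784 pixels should be a categorical over `K` values, and `K = 4` is a hypothetical, deliberately small choice (not taken from the original model) that keeps the tensor small. Observations are then integer category indices of shape (10, 200, 784) rather than probabilities.

```python
import torch
from torch.distributions import Categorical, Independent

# Hypothetical number of values per pixel; an assumption, not from the original model
K = 4

# Explicit trailing category dim: probs has shape (10, 200, 784, K)
probs = torch.rand(10, 200, 784, K)
base = Categorical(probs=probs)
print(base.batch_shape, base.event_shape)  # torch.Size([10, 200, 784]) torch.Size([])

# Independent(base, 1) (Pyro's .to_event(1)) moves 784 into the event shape
b = Independent(base, 1)
print(b.batch_shape, b.event_shape)  # torch.Size([10, 200]) torch.Size([784])

# Observations are integer category indices with shape batch_shape + event_shape
xs = torch.randint(0, K, (10, 200, 784))
print(b.log_prob(xs).shape)  # torch.Size([10, 200])
```

With this layout the Categorical version lines up with the Bernoulli shapes, and the log-probability is summed over the 784 pixels for each (10, 200) batch element.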