dist.Bernoulli recognizes the right batch_shape, but dist.Categorical doesn't

Hi all, I’m back again,
Sorry to bother you!
The original code that samples xs in model() of ss_vae_M2.py is:

loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)

where the shape of loc is [10,200,784] (10 for parallel enumeration, 200 for the original batch size, 784 for the pixels of each image), and with .to_event(1) the Bernoulli ends up with batch_shape [10,200] and event_shape [784]:

b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

and we have:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])
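A minimal sketch of the shape bookkeeping behind .to_event(n), written in plain Python so it runs without torch. The helper to_event_shapes is hypothetical (not a Pyro API); it assumes the rule that to_event(n) moves the rightmost n batch dimensions into the event shape, and that Bernoulli treats every entry of its parameter tensor as a separate batch element.

```python
# Hypothetical helper (not part of Pyro): mimics how .to_event(n) relabels
# dimensions by moving the rightmost n batch dims into the event shape.
def to_event_shapes(batch_shape, event_shape, n):
    if n > len(batch_shape):
        raise ValueError("cannot reinterpret more dims than batch_shape has")
    split = len(batch_shape) - n
    return batch_shape[:split], batch_shape[split:] + event_shape


# Bernoulli is element-wise: every entry of loc is its own batch element,
# so batch_shape is the full parameter shape and event_shape is empty.
loc_shape = (10, 200, 784)
bern_batch, bern_event = loc_shape, ()

print(to_event_shapes(bern_batch, bern_event, 1))  # ((10, 200), (784,))
```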

However, when I replace Bernoulli with Categorical and leave all the other code and data unchanged, I get:

loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

but the result becomes:

b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])
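The same bookkeeping reproduces this result if we assume that Categorical reads the last dimension of its parameter tensor as the categories of a single draw rather than as a batch dimension: the 784 axis is consumed as categories, so to_event(1) can only move the 200 axis into the event shape. Both helpers below are hypothetical illustrations, not Pyro functions.

```python
# Hypothetical helper (not part of Pyro): moves the rightmost n batch dims
# into the event shape, as .to_event(n) does.
def to_event_shapes(batch_shape, event_shape, n):
    split = len(batch_shape) - n
    return batch_shape[:split], batch_shape[split:] + event_shape


# Hypothetical helper: Categorical treats the last parameter dim as the
# category axis, so it is dropped from batch_shape and event_shape stays ().
def categorical_shapes(probs_shape):
    return probs_shape[:-1], ()


batch, event = categorical_shapes((10, 200, 784))
print(batch, event)                      # (10, 200) ()
print(to_event_shapes(batch, event, 1))  # ((10,), (200,))
```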

Will someone help me figure this out?

I looked at the original implementation of Categorical in torch.distributions.Categorical and found that it does recognize the right shapes if “.to_event(1)” is not applied:

b = dist.Categorical(loc, validate_args=False)
print(f"b _batch_shape:{b._batch_shape}")
print(f"b _num_events:{b._num_events}")

output:

b _batch_shape:torch.Size([10, 200])
b _num_events:784
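A quick check with plain torch.distributions (no Pyro), assuming a random tensor of the same shape can stand in for the decoder output. Note that _num_events is a private attribute of the torch implementation, so its name may change between torch versions.

```python
import torch

# Random stand-in for the decoder output `loc` (assumption: same shape).
probs = torch.rand(10, 200, 784)
c = torch.distributions.Categorical(probs, validate_args=False)

print(c.batch_shape)   # torch.Size([10, 200])
print(c.event_shape)   # torch.Size([])
print(c._num_events)   # 784 (private attribute: number of categories)
```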

Note that _batch_shape and ‘_num_events’ are internal attributes of the original torch.distributions.Categorical implementation, while if we query the corresponding public properties through the Pyro wrapper:

b = dist.Categorical(loc, validate_args=False)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

the results would be:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([])
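One way to see what event_shape [] means here is to compare sample shapes, sketched again with plain torch.distributions and a random stand-in for loc: Bernoulli draws one 0/1 value per pixel, while Categorical draws a single category index in [0, 784) per (enumeration, batch) element.

```python
import torch

# Random stand-in for the decoder output `loc` (assumption: same shape).
probs = torch.rand(10, 200, 784)

bern = torch.distributions.Bernoulli(probs, validate_args=False)
cat = torch.distributions.Categorical(probs, validate_args=False)

# Bernoulli: one 0/1 draw per pixel.
print(bern.sample().shape)  # torch.Size([10, 200, 784])
# Categorical: one index in [0, 784) per (enumeration, batch) element.
print(cat.sample().shape)   # torch.Size([10, 200])
```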

which is different from the Bernoulli case:

b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

output:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])
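If the intent is one categorical variable per pixel (batch_shape [10,200] and event_shape [784], like the Bernoulli version), a sketch under that assumption is to give the decoder output an extra trailing axis of size K for the categories and then reinterpret the pixel axis as an event. Here K = 4 is a hypothetical number of intensity levels, and torch's Independent(d, 1) plays the role of Pyro's .to_event(1). Under this assumption xs would also need to hold integer class indices in [0, K) rather than pixel intensities.

```python
import torch
from torch.distributions import Categorical, Independent

K = 4  # hypothetical number of intensity levels per pixel
# Assumed decoder output with an extra trailing category axis.
probs = torch.rand(10, 200, 784, K)

per_pixel = Categorical(probs, validate_args=False)  # batch (10, 200, 784), event ()
d = Independent(per_pixel, 1)  # moves the pixel axis into the event shape

print(d.batch_shape)  # torch.Size([10, 200])
print(d.event_shape)  # torch.Size([784])

x = d.sample()
print(x.shape)              # torch.Size([10, 200, 784])
print(d.log_prob(x).shape)  # torch.Size([10, 200]), summed over the 784 pixels
```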