Hi fehiepsi, could you please help me with this issue: "dist.Bernoulli recognize the right batch_size, but dist.Categorical won't" (Tutorials, Pyro Discussion Forum)?
The same code with the Bernoulli distribution recognizes the right batch_shape and event_shape:
```python
import pyro
import pyro.distributions as dist

# loc and xs are defined elsewhere in the model (not shown)
x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")
```
outputs:
```
b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])
```
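For comparison, here is a minimal sketch of the same shape bookkeeping using plain `torch.distributions`. It assumes `loc` has shape `(10, 200, 784)`; the post does not show `loc`, so this shape is only inferred from the printed output, and the tensor below is random dummy data. Pyro's `.to_event(1)` returns an `Independent` distribution, so `torch.distributions.Independent(..., 1)` is used here to avoid depending on Pyro:

```python
import torch
from torch.distributions import Bernoulli, Independent

# Assumed shape of loc (inferred, not shown in the post); dummy values in [0, 1)
loc = torch.rand(10, 200, 784)

# Bernoulli treats every entry of probs as an independent coin flip,
# so batch_shape is the full shape of loc and event_shape is empty.
base = Bernoulli(probs=loc, validate_args=False)
print(f"base batch_shape:{base.batch_shape}")  # torch.Size([10, 200, 784])
print(f"base event_shape:{base.event_shape}")  # torch.Size([])

# Reinterpret the rightmost batch dim as an event dim (what .to_event(1) does)
b = Independent(base, 1)
print(f"b batch_shape:{b.batch_shape}")  # torch.Size([10, 200])
print(f"b event_shape:{b.event_shape}")  # torch.Size([784])
```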
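but the same code with Categorical:
```python
import pyro
import pyro.distributions as dist

# loc and xs are defined elsewhere in the model (not shown)
x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")
```
outputs:
```
b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])
```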
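A likely explanation is how `Categorical` reads its parameter: the rightmost dimension of `probs` (or `logits`) is the category axis, so it is consumed and does not appear in `batch_shape`. The minimal sketch below reproduces this with plain `torch.distributions`, assuming `loc` has shape `(10, 200, 784)` (inferred from the Bernoulli output, not shown in the post) and filling it with random dummy data. `Independent(..., 1)` stands in for Pyro's `.to_event(1)`:

```python
import torch
from torch.distributions import Categorical, Independent

# Assumed shape of loc (inferred, not shown in the post); dummy values
loc = torch.rand(10, 200, 784)

# The last dim (784) is read as the number of categories, so one draw is a
# single integer index and batch_shape is loc.shape[:-1].
base = Categorical(probs=loc, validate_args=False)  # probs normalized along dim -1
print(f"base batch_shape:{base.batch_shape}")  # torch.Size([10, 200])
print(f"base event_shape:{base.event_shape}")  # torch.Size([])

# .to_event(1) therefore moves the 200 dim, not the 784 dim, into event_shape
b = Independent(base, 1)
print(f"b batch_shape:{b.batch_shape}")  # torch.Size([10])
print(f"b event_shape:{b.event_shape}")  # torch.Size([200])

# With this layout, obs would be integer class indices of shape (10, 200)
obs = torch.zeros(10, 200, dtype=torch.long)
print(b.log_prob(obs).shape)  # torch.Size([10])
```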
I dug a little further into the torch.distributions.Categorical implementation; maybe you could help me out in that thread.
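If the goal is one categorical variable per position along the 784 dimension (mirroring the Bernoulli layout), the parameter tensor would need an extra trailing category dimension. A minimal sketch under that assumption; the number of categories `K` is a hypothetical placeholder not taken from the post, and `probs` is random dummy data:

```python
import torch
from torch.distributions import Categorical, Independent

K = 4  # hypothetical number of categories per position
probs = torch.rand(10, 200, 784, K)  # trailing dim = categories

# batch_shape = probs.shape[:-1] = (10, 200, 784)
base = Categorical(probs=probs, validate_args=False)

# Reinterpret the 784 dim as part of the event, matching the Bernoulli case
b = Independent(base, 1)
print(f"b batch_shape:{b.batch_shape}")  # torch.Size([10, 200])
print(f"b event_shape:{b.event_shape}")  # torch.Size([784])

# obs then holds an integer index in [0, K) for each of the 784 positions
obs = torch.randint(0, K, (10, 200, 784))
print(b.log_prob(obs).shape)  # torch.Size([10, 200])
```

In Pyro the equivalent would be calling `.to_event(1)` on `dist.Categorical(probs)` with such a four-dimensional `probs`.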