How does Pyro's Categorical know the corresponding position for y?


pyro.sample("y", dist.Categorical(logits=logits), obs=a)

For example, in MNIST classification the obs (the label) ranges from 0 to 9, and logits is a [batch, 10] tensor. If we do not convert the labels to one-hot, how does the Categorical distribution know which position is 1 and which positions are 0?


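@yikuanlee_pyro Samples of a Categorical distribution over 10 classes are integers from 0 to 9, and that is a design choice of the Categorical distribution. An observed value k refers to position k along the last dimension of logits, so log_prob(k) is simply the normalized log-probability at that position; no one-hot conversion is needed. You can find its definition here.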
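A minimal sketch of this behaviour follows. It uses torch.distributions directly, on the assumption that pyro.distributions.Categorical wraps the same PyTorch class, so its log_prob should behave identically. The batch size, seed and labels are made up for illustration. The sketch checks that an integer label k selects entry k of the log-softmaxed logits, and that this matches the one-hot parameterization.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, OneHotCategorical

torch.manual_seed(0)

# Hypothetical batch of 4 examples over 10 classes (MNIST digits 0-9).
logits = torch.randn(4, 10)
labels = torch.tensor([3, 0, 9, 7])  # integer class indices, not one-hot

# Integer-valued Categorical: log_prob(k) reads position k of the last dim.
lp_int = Categorical(logits=logits).log_prob(labels)

# The same values by hand: normalize with log_softmax, then index by label.
lp_manual = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

# The one-hot parameterization gives the same log-probabilities.
one_hot = F.one_hot(labels, num_classes=10).float()
lp_onehot = OneHotCategorical(logits=logits).log_prob(one_hot)

print(lp_int.shape)  # torch.Size([4])
print(torch.allclose(lp_int, lp_manual), torch.allclose(lp_int, lp_onehot))
```

In a Pyro model you would pass the integer labels straight in, e.g. obs=labels, and Pyro scores them with this same log_prob. For a batch of independent observations you would typically wrap the sample statement in pyro.plate.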