I tried using the categorical distribution in NumPyro and found that
dist.Categorical(probs=jnp.array([0.7, 0.3, 0.3, 0.3])).sample(key) only produces 0 or 1, never 2 or 3.
This suggests that the sampler accumulates the probabilities until the running total reaches 1 and silently ignores every later entry, without raising any error.
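The behaviour can be reproduced in plain JAX. The sketch below assumes the sampler works by inverse CDF: draw u uniformly from [0, 1), take the cumulative sum of probs, and return how many cumulative entries fall below u. The function name cumsum_categorical is hypothetical and is not a NumPyro API. With probs [0.7, 0.3, 0.3, 0.3] the cumulative sum is [0.7, 1.0, 1.3, 1.6], and since u is always below 1.0, indices 2 and 3 can never be returned.

```python
import jax
import jax.numpy as jnp


def cumsum_categorical(key, probs, num_samples):
    """Hypothetical inverse-CDF categorical sampler (assumed, not NumPyro's code)."""
    cdf = jnp.cumsum(probs)                        # running total of the probabilities
    u = jax.random.uniform(key, (num_samples, 1))  # uniform draws in [0, 1)
    # Count how many cdf entries lie strictly below u; that count is the sampled index.
    # Entries after the point where cdf reaches 1.0 can never be selected.
    return jnp.sum(cdf < u, axis=-1)


probs = jnp.array([0.7, 0.3, 0.3, 0.3])
samples = cumsum_categorical(jax.random.PRNGKey(0), probs, 10_000)
counts = jnp.bincount(samples, length=4)  # how often each index was drawn
print(counts)
```

Counts for indices 2 and 3 stay at zero, and index 0 appears roughly 70% of the time.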
I believe that normalizing the probabilities, as Pyro does, would be better, because it frees the user from checking whether the probabilities sum to 1, for example when they want to set probs quickly with jnp.ones. Alternatively, if normalization is hard to implement, it would be better to raise an error when probs do not sum to 1, telling the user to check the probs parameter.
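Both proposals can be sketched on the user side in a few lines of JAX. The helpers normalize_probs and check_probs below are hypothetical illustrations, not existing NumPyro functions: the first rescales non-negative weights so the last axis sums to 1, and the second raises a ValueError instead of silently truncating. The tolerance atol=1e-5 is an arbitrary assumption.

```python
import jax.numpy as jnp


def normalize_probs(weights):
    """Hypothetical helper: rescale non-negative weights to sum to 1 on the last axis."""
    weights = jnp.asarray(weights, dtype=jnp.float32)
    return weights / jnp.sum(weights, axis=-1, keepdims=True)


def check_probs(probs, atol=1e-5):
    """Hypothetical helper: raise an error instead of silently ignoring extra mass."""
    probs = jnp.asarray(probs)
    total = jnp.sum(probs, axis=-1)
    # Python-level checks: this runs eagerly and will not work inside jax.jit.
    if bool(jnp.any(probs < 0)) or not bool(jnp.allclose(total, 1.0, atol=atol)):
        raise ValueError(
            f"probs must be non-negative and sum to 1 along the last axis (got sums {total}); "
            "check the probs parameter or normalize it first."
        )
    return probs


weights = jnp.ones(4)            # quick uniform weights, the jnp.ones use case
probs = normalize_probs(weights)  # each entry becomes 0.25
check_probs(probs)                # passes without raising
```

The normalized array can then be passed as probs= to dist.Categorical. As far as I know, passing logits=jnp.log(weights) instead should also avoid the problem, since logits are normalized internally, but I have not checked this against every NumPyro version.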