Suggestion on the probs parameter of the Categorical distribution in NumPyro

I tried to use categorical distributions in NumPyro and found that
dist.Categorical(probs=jnp.array([0.7, 0.3, 0.3, 0.3])).sample(key) only produced 0 or 1, never 2 or 3.
This suggests that it accumulates probabilities until the running total reaches 1, and all later probability entries are silently ignored without any error.
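The behavior described above can be reproduced with a small inverse-CDF sampler in plain JAX. This is only a sketch of the mechanism the post hypothesizes, not NumPyro's actual Categorical implementation; inverse_cdf_sample is a hypothetical helper. Because the cumulative sum of [0.7, 0.3, 0.3, 0.3] already reaches 1 at the second entry, a uniform draw in [0, 1) can never land on categories 2 or 3.

```python
import jax
import jax.numpy as jnp


def inverse_cdf_sample(key, probs, num_samples):
    """Hypothetical helper: draw category indices by inverting cumsum(probs).

    No normalization is done, so any mass beyond a running total of 1
    is unreachable by a uniform draw in [0, 1).
    """
    cdf = jnp.cumsum(probs)                      # approx. [0.7, 1.0, 1.3, 1.6]
    u = jax.random.uniform(key, (num_samples,))  # u in [0, 1)
    # Index of the first cdf entry strictly greater than u.
    return jnp.searchsorted(cdf, u, side="right")


probs = jnp.array([0.7, 0.3, 0.3, 0.3])
samples = inverse_cdf_sample(jax.random.PRNGKey(0), probs, 10_000)
print(jnp.unique(samples))  # only categories 0 and 1 appear
```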

I believe normalizing these probabilities, as Pyro does, would be better, because it frees the user from checking whether the probabilities sum to 1, for example when the user wants to set probs quickly with jnp.ones. Alternatively, if normalization is problematic to implement, it would be better to raise an error when probs doesn't sum to 1, telling the user to check the probs parameter.
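Both options in the suggestion can be handled on the user side today. The sketch below uses plain JAX; normalize and check_probs are hypothetical helpers, not NumPyro functions. It rescales quick jnp.ones weights into a simplex, shows that jax.random.categorical (which takes unnormalized logits) samples the same distribution from jnp.log of the raw weights, and raises an error when probs does not sum to 1. The same normalized probs, or logits=jnp.log(weights), could then be passed to dist.Categorical.

```python
import jax
import jax.numpy as jnp


def normalize(probs):
    """Hypothetical helper: rescale non-negative weights to sum to 1 on the last axis."""
    return probs / jnp.sum(probs, axis=-1, keepdims=True)


def check_probs(probs, atol=1e-6):
    """Hypothetical helper: raise if probs does not sum to 1 on the last axis."""
    total = jnp.sum(probs, axis=-1)
    if not bool(jnp.all(jnp.abs(total - 1.0) < atol)):
        raise ValueError(f"probs must sum to 1 along the last axis, got {total}")


weights = jnp.ones(4)       # quick "uniform" weights that sum to 4
probs = normalize(weights)  # [0.25, 0.25, 0.25, 0.25]
check_probs(probs)          # passes; check_probs(weights) would raise

# jax.random.categorical treats its input as unnormalized logits, so the
# log of the raw weights yields the same distribution as `probs`.
key = jax.random.PRNGKey(1)
samples = jax.random.categorical(key, jnp.log(weights), shape=(10_000,))
freqs = jnp.bincount(samples, length=4) / samples.shape[0]
print(probs)
print(freqs)  # each entry close to 0.25
```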

Which version of NumPyro are you using? This might have been fixed in version 0.10.0 via PR #1419.

Thanks, I have updated to 0.10.0 and the problem is solved.

Note that in NumPyro we typically don’t modify the input. It is better to make sure that probs is a simplex; otherwise, when validation is enabled, you may see warnings or errors. In addition, if probs is not normalized, log_prob might not give the expected result. One way to make the handling of probs more Pyro-friendly would be to revise the current implementations of the sample and log_prob methods that use self.probs; please open a feature request if needed.
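To see why an unnormalized probs vector can make log_prob misleading, here is a plain JAX sketch (not NumPyro's code) comparing log-probabilities read straight from the raw entries with the same values normalized in log space. With validation enabled, for example via numpyro.enable_validation(), NumPyro should instead reject such a vector for violating its simplex constraint.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

probs = jnp.array([0.7, 0.3, 0.3, 0.3])  # total mass 1.6, not a simplex

# Log-probabilities read directly from the unnormalized entries.
naive_log_probs = jnp.log(probs)

# Normalizing in log space (the same idea a logits parameterization uses).
normalized_log_probs = naive_log_probs - logsumexp(naive_log_probs)

# The naive values are too large by log(1.6) for every category, and they do
# not describe a valid distribution: exp(...) sums to 1.6 instead of 1.
print(jnp.exp(naive_log_probs).sum())          # ~1.6
print(jnp.exp(normalized_log_probs).sum())     # ~1.0
print(naive_log_probs - normalized_log_probs)  # ~0.47 for every category
```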

Thanks!
I saw the error pop up after enabling validation. I think your idea makes sense: it’s the user's responsibility to make their distribution mathematically valid.
Perhaps it would be better to mention this in the documentation of the Categorical distribution?

I posted a question several days ago; the model is quite straightforward. Would you please take some time to look at it?
Sorry for asking so many questions. I’m quite new to Pyro/NumPyro, and my work is somewhat beyond my current ability.

Definitely! Please make a feature/pull request to clarify this point.