Hi there. This is more of a beginner question than anything else, as I’m still trying to wrap my brain around how NumPyro works. I am trying to find a way to replicate this simple blog post in NumPyro. If you don’t have a Medium subscription, I’ll describe the problem below:
Let’s say you have N dice, indexed by n \in \{1, \dots, N\}, where die n has n+2 sides. Someone picks one die at random and starts rolling it, telling you each result. What I want is the posterior probability that the chosen die is die n given the observations \mathcal{D}. Assume all dice are fair.
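Written out as a worked equation, assuming the rolls \mathcal{D} = (d_1, \dots, d_K) are independent given the die and the die is picked uniformly, so P(n) = 1/N (K, the number of rolls, is a symbol introduced here). A single roll that is impossible for die n makes its whole product 0:

$$
P(n \mid \mathcal{D}) = \frac{P(n)\prod_{i=1}^{K} P(d_i \mid n)}{\sum_{m=1}^{N} P(m)\prod_{i=1}^{K} P(d_i \mid m)},
\qquad
P(d \mid n) = \begin{cases} \dfrac{1}{n+2} & \text{if } 1 \le d \le n+2 \\ 0 & \text{otherwise} \end{cases}
$$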
This is distinctly different from the problem of a single unfair die, where you infer the probability of each side and any side can always have p > 0. In this problem, if you observe a 5, the die cannot be the 3-sided or the 4-sided one, so those posterior probabilities are exactly 0. This is the part I’m having trouble coding up. The update using Bayes’ rule is easy to do by hand, but I’m not sure how to express it in NumPyro.
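For reference, the hand-computed Bayes update can be sketched in plain NumPy, assuming a uniform prior over which die was picked. The function name `dice_posterior` is hypothetical, introduced only for illustration; index 0 of the result corresponds to die n = 1 (3 sides). Working in log space lets an impossible roll contribute -inf, which becomes an exact 0 after exponentiating:

```python
import numpy as np

def dice_posterior(observations, n_dice=5):
    """Exact P(die n | observations) for fair dice where die n has n + 2 sides (hypothetical helper)."""
    sides = np.arange(1, n_dice + 1) + 2          # [3, 4, ..., n_dice + 2]
    log_post = np.full(n_dice, -np.log(n_dice))   # uniform prior over which die was picked
    for d in observations:
        # A fair s-sided die gives P(d) = 1/s if 1 <= d <= s, and 0 otherwise
        possible = (d >= 1) & (d <= sides)
        log_post = log_post + np.where(possible, -np.log(sides), -np.inf)
    # Normalize in a numerically stable way; impossible dice stay at exactly 0
    post = np.exp(log_post - np.max(log_post))
    return post / post.sum()

print(dice_posterior([4, 5, 4, 6, 4]))  # roughly [0, 0, 0, 0.684, 0.316]
```

The roll of 6 rules out the 3-, 4- and 5-sided dice, and the remaining mass is split in proportion to (1/6)^5 and (1/7)^5.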
My starting point is simply a Dirichlet prior with equal concentration for every die, so that a priori each die is equally likely to be the one selected:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

N_DICE = 5  # sides [3, 4, 5, 6, 7]

def model(observations):
    # observations might be something like [4, 5, 4, 6, 4]
    theta = numpyro.sample('theta', dist.Dirichlet(jnp.ones(N_DICE)))
    with numpyro.plate('observations', len(observations)):
        pass  # Unsure what goes here
```
Do I need a custom distribution, or is there another way to express this?
Thanks so much for the help!