Correct way to code up a simple dice-roll problem in NumPyro

Hi there. This is more of a beginner question than anything else, as I’m still trying to wrap my brain around how NumPyro works. I am trying to find a way to replicate this simple blog post in NumPyro. If you don’t have a Medium subscription, I’ll describe the problem below:

Let’s say you have $N$ dice, indexed by $n \in \{1, \dots, N\}$, where die $n$ has $n + 2$ sides. Someone picks a die at random and starts rolling it, telling you each observation. What I want is a posterior over which die was chosen, i.e. the probability that it is die $n$ given the observations $\mathcal{D}$. Assume all dice are fair.

This is distinctly different from the problem of a single unfair die, where you try to infer the probability of observing each side: there, every side can always have $p > 0$. In this problem, say you observe a 5. There is no way it came from the 3- or 4-sided die, so those probabilities are exactly 0. This is the part I’m having trouble coding up. The update scheme using Bayes’s rule is quite easy, but I’m not sure how to express it in NumPyro.
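Written out, the Bayes update is $P(n \mid \mathcal{D}) \propto P(n) \prod_{x \in \mathcal{D}} \frac{1}{n+2}\,\mathbb{1}[x \le n+2]$; the indicator is what zeroes out dice that are too small. A minimal sketch in plain Python, assuming a uniform prior over the $N$ dice and a non-empty list of observations. `dice_posterior` is a hypothetical helper name, and `Fraction` keeps the arithmetic exact so impossible dice come out as exactly 0.

```python
from fractions import Fraction


def dice_posterior(observations, n_dice):
    """Exact P(die | observations), keyed by number of sides; die n has n + 2 sides."""
    sides = [n + 2 for n in range(1, n_dice + 1)]
    prior = Fraction(1, n_dice)  # assumption: die picked uniformly at random
    unnormalized = []
    for k in sides:
        if max(observations) > k:
            # A k-sided die can never show a face above k: likelihood is exactly 0.
            unnormalized.append(Fraction(0))
        else:
            # Each roll of a fair k-sided die has probability 1/k.
            unnormalized.append(prior * Fraction(1, k) ** len(observations))
    total = sum(unnormalized)
    return {k: u / total for k, u in zip(sides, unnormalized)}
```

For a single observed 5 with $N = 5$, the 3- and 4-sided dice get exactly 0, and the 5-, 6- and 7-sided dice get 42/107, 35/107 and 30/107.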

My starting point is simply a Dirichlet distribution that gives each die an equal prior probability of being the one selected.

```python
import numpy as np
import numpyro
import numpyro.distributions as dist

N_DICE = 5  # sides [3, 4, 5, 6, 7]
def model(observations):
    # observations might be something like [4, 5, 4, 6, 4]
    theta = numpyro.sample('theta', dist.Dirichlet(np.ones(N_DICE)))
    with numpyro.plate('observations', len(observations)):
        ...  # Unsure what goes here
```

Is a custom distribution or something required?

Thanks so much for the help!

do you assume each die is fair?

@martinjankowiak yes, all dice are fair!

i think all you need is a model with two components.

the first is a prior over a scalar discrete latent variable (say uniform) over {0, 1, …, N-1}. call this latent variable die.

then, conditioned on die you define a likelihood (e.g. via an observed sample statement).

in this case since your likelihood is a bit funky it’s probably best to use a factor statement. to do that you just need to compute the scalar log_prob that corresponds to the die-conditioned likelihood. this will be -infinity if the observation is impossible; it will be Multinomial.log_prob(...) otherwise.

@martinjankowiak I’ll take a shot at this but would definitely appreciate it if you could humor me a bit here since it’s a weird case (as you say the likelihood is a bit “funky”). Sounds like you’re implying I should define a new Distribution that implements a custom log_prob method, but I’m not sure. I understand if you don’t have the time. Thanks!

you don’t need a custom distribution. you can use factor. something like

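```python
import jax.numpy as jnp
import numpyro
die = numpyro.sample("die", ...)
# log_likelihood: the die-conditioned log-likelihood; -inf if the data are impossible
log_prob = jnp.where(die_could_plausibly_generate_data(die, data), log_likelihood, -jnp.inf)
numpyro.factor("my_observation", log_prob)
```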

where die_could_plausibly_generate_data(...) is a function you need to define