I’ve tried solving exercise 15H7 from Statistical Rethinking. I have a solution, but I’m not overly keen on it.
Here’s the exercise:
Some lad named Andrew made an eight-sided spinner. He wanted to know if it is fair. So he
spun it a bunch of times, recording the counts of each value. Then he accidentally spilled coffee over
the 4s and 5s. The surviving data are summarized below.
Value      1   2   3   4   5   6   7   8
Frequency  18  19  22  ?   ?   19  20  22
Your job is to impute the two missing values in the table above. Andrew doesn’t remember how many
times he spun the spinner. So you will have to assign a prior distribution for the total number of spins
and then marginalize over the unknown total. Andrew is not sure the spinner is fair (every value is
equally likely), but he’s confident that none of the values is twice as likely as any other. Use a Dirichlet
distribution to capture this prior belief. Plot the joint posterior distribution of 4s and 5s.
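The "none of the values is twice as likely as any other" belief can be checked by simulation for a candidate Dirichlet concentration. The following is a minimal sketch, assuming a symmetric Dirichlet over the eight values; the helper name `frac_exceeding_ratio` and the candidate concentrations other than 56 (the value used in the model further down) are illustrative assumptions, not part of the exercise.

```python
import numpy as np

rng = np.random.default_rng(0)


def frac_exceeding_ratio(concentration, n_sides=8, ratio=2.0, n_draws=10_000):
    """Fraction of prior draws in which the most likely side is more than
    `ratio` times as likely as the least likely side."""
    draws = rng.dirichlet(np.full(n_sides, concentration), size=n_draws)
    return np.mean(draws.max(axis=1) / draws.min(axis=1) > ratio)


# Compare a flat prior, a moderately tight one, and the value used in the model.
for alpha in [1.0, 10.0, 56.0]:
    print(alpha, frac_exceeding_ratio(alpha))
```

A concentration for which this fraction is small puts most of the prior mass on spinners that satisfy Andrew's belief.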
The way I’ve attempted this is to model each frequency as a Poisson distribution, where the rate is given by the probability of a side being picked multiplied by the total number of times the spinner was spun:
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, Predictive
numpyro.set_host_device_count(4)
df = pd.DataFrame(
    {
        "Value": [0, 1, 2, 3, 4, 5, 6, 7],
        "Frequency": [18, 19, 22, np.nan, np.nan, 19, 20, 22],
    }
)
def model(value, frequency):
    # Positions of the coffee-stained (missing) frequencies.
    n_missing = int(np.isnan(df["Frequency"]).sum())
    idx_missing = np.where(np.isnan(df["Frequency"]))[0]
    n_vals = len(np.unique(value))
    # Continuous stand-in prior for the total number of spins.
    n_spins = numpyro.sample("n_spins", dist.Normal(160, 20))
    probs = numpyro.sample("probs", dist.Dirichlet(jnp.ones(n_vals) * 56))
    # Note: sigma is sampled but does not enter the likelihood below.
    sigma = numpyro.sample("sigma", dist.HalfNormal(5))
    rate = numpyro.deterministic("rate", probs * n_spins)
    frequency_impute = numpyro.sample(
        "frequency_impute", dist.Normal(20, 5).expand([n_missing]).mask(False)
    )
    # jax.ops.index_update was removed from JAX; .at[].set() is the replacement.
    frequency_imputed = jnp.asarray(frequency).at[idx_missing].set(frequency_impute)
    numpyro.sample("frequency", dist.Poisson(rate[value]), obs=frequency_imputed)
mcmc = MCMC(
    NUTS(model, target_accept_prob=0.9),
    num_chains=4,
    num_samples=1000,
    num_warmup=1000,
)
mcmc.run(
    jax.random.PRNGKey(0),
    value=df["Value"].to_numpy(),
    frequency=df["Frequency"].to_numpy(),
)
mcmc.print_summary()
This “works” in the sense that it provides reasonable-looking imputations for the missing frequencies:
                       mean    std  median    5.0%   95.0%    n_eff  r_hat
frequency_impute[0]   20.09   5.52   19.70   10.74   28.45  2961.27   1.00
frequency_impute[1]   20.15   5.48   19.81   11.22   29.00  3089.66   1.00
n_spins              160.77  12.07  160.42  141.51  180.78  3246.51   1.00
probs[0]               0.12   0.01    0.12    0.10    0.14  4263.82   1.00
probs[1]               0.12   0.01    0.12    0.10    0.15  4613.60   1.00
probs[2]               0.13   0.01    0.13    0.11    0.15  4068.24   1.00
probs[3]               0.13   0.02    0.12    0.10    0.15  3381.19   1.00
probs[4]               0.13   0.02    0.12    0.10    0.15  3462.87   1.00
probs[5]               0.12   0.01    0.12    0.10    0.14  4129.95   1.00
probs[6]               0.13   0.01    0.12    0.10    0.15  4117.11   1.00
probs[7]               0.13   0.01    0.13    0.11    0.15  4036.46   1.00
sigma                  4.05   3.05    3.47    0.00    8.31  4387.28   1.00
Number of divergences: 0
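To look at the joint posterior of the 4s and 5s that the exercise asks for, the draws for the two imputed counts can be binned into a 2D histogram. This is only a sketch: the helper `joint_histogram` is a hypothetical name, and the draws below are synthetic stand-ins so the snippet runs on its own. With the fitted model, the array would instead come from `mcmc.get_samples()["frequency_impute"]`.

```python
import numpy as np


def joint_histogram(impute_draws, bin_width=1.0):
    """Bin draws of the two imputed counts into a normalised 2D histogram."""
    draws = np.asarray(impute_draws)
    lo = np.floor(draws.min())
    hi = np.ceil(draws.max())
    # Shared bin edges so both axes are directly comparable.
    edges = np.arange(lo, hi + bin_width, bin_width)
    hist, x_edges, y_edges = np.histogram2d(
        draws[:, 0], draws[:, 1], bins=[edges, edges]
    )
    return hist / hist.sum(), x_edges, y_edges


# Synthetic stand-in draws (an assumption, NOT the real posterior samples).
rng = np.random.default_rng(0)
fake_draws = rng.normal(loc=20.0, scale=5.5, size=(4000, 2))
density, x_edges, y_edges = joint_histogram(fake_draws)
```

The result can be passed to a plotting call such as matplotlib's `plt.pcolormesh(x_edges, y_edges, density.T)`.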
However, I’m bothered by the fact that n_spins is modelled as a Normal distribution, whereas we know it can only take non-negative integer values.
Is there a better way to solve this problem?
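One possible direction that keeps the total an integer: since the six surviving counts sum to 120, the total is 120 plus the two missing counts, so a prior on the total can be evaluated on an integer grid over the 4s and 5s. The sketch below is not a definitive solution. It assumes a Poisson prior with mean 160 on the total (echoing the centre of the Normal prior above), a hypothetical upper bound of 60 on each missing count, and it averages over draws from the Dirichlet prior as a crude Monte Carlo marginalisation of the probabilities.

```python
import numpy as np

# Surviving counts for values 1, 2, 3, 6, 7, 8 (from the exercise table).
observed = np.array([18, 19, 22, 19, 20, 22])
obs_idx = np.array([0, 1, 2, 5, 6, 7])  # zero-based positions of those values

# Integer grid of candidate counts for the 4s and 5s.
max_missing = 60  # hypothetical upper bound, chosen for illustration
grid = np.arange(max_missing + 1)
x4, x5 = np.meshgrid(grid, grid, indexing="ij")
n_total = observed.sum() + x4 + x5  # total spins implied by each grid cell

# Lookup table of log(k!) for k = 0 .. largest total on the grid.
log_fact = np.concatenate(
    [[0.0], np.cumsum(np.log(np.arange(1, n_total.max() + 1)))]
)


def log_joint(probs, prior_mean=160.0):
    """Unnormalised log p(x4, x5, observed | probs) on the grid."""
    # Assumed Poisson prior on the total number of spins.
    log_prior = n_total * np.log(prior_mean) - prior_mean - log_fact[n_total]
    # Multinomial log-likelihood of all eight counts given the total.
    log_lik = (
        log_fact[n_total]
        - log_fact[observed].sum() - log_fact[x4] - log_fact[x5]
        + (observed * np.log(probs[obs_idx])).sum()
        + x4 * np.log(probs[3]) + x5 * np.log(probs[4])
    )
    return log_prior + log_lik


# Average the joint over Dirichlet prior draws of the side probabilities.
rng = np.random.default_rng(0)
prob_draws = rng.dirichlet(np.full(8, 56.0), size=500)
log_post = np.stack([log_joint(p) for p in prob_draws])
posterior = np.exp(log_post - log_post.max()).sum(axis=0)
posterior /= posterior.sum()  # joint posterior over (4s, 5s) on the grid
```

With a Poisson prior on the total, the factorial of the total cancels against the multinomial coefficient, so for fixed probabilities the two missing counts are independent Poisson variables. A different integer-valued prior would not simplify this way, but the same grid evaluation would still apply.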