Impute missing counts (Statistical Rethinking 15H7)

I’ve tried solving Statistical Rethinking exercise 15H7. I have a solution, but I’m not overly keen on it.

Here’s the exercise:

Some lad named Andrew made an eight-sided spinner. He wanted to know if it is fair. So he
spun it a bunch of times, recording the counts of each value. Then he accidentally spilled coffee over
the 4s and 5s. The surviving data are summarized below.

Value       1   2   3   4   5   6   7   8
Frequency  18  19  22   ?   ?  19  20  22

Your job is to impute the two missing values in the table above. Andrew doesn’t remember how many
times he spun the spinner. So you will have to assign a prior distribution for the total number of spins
and then marginalize over the unknown total. Andrew is not sure the spinner is fair (every value is
equally likely), but he’s confident that none of the values is twice as likely as any other. Use a Dirichlet
distribution to capture this prior belief. Plot the joint posterior distribution of 4s and 5s.
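One way to turn "none of the values is twice as likely as any other" into a Dirichlet concentration is a quick prior simulation: draw probability vectors from a symmetric Dirichlet and check how often max(p) / min(p) stays below 2. This is a sketch using NumPy's `Generator.dirichlet`; the candidate concentrations are arbitrary, and 56 is simply the value used in the model code below.

```python
import numpy as np


def frac_within_ratio(alpha, k=8, max_ratio=2.0, n_draws=4000, seed=0):
    """Share of Dirichlet(alpha * ones(k)) draws whose largest probability is
    less than max_ratio times the smallest one."""
    rng = np.random.default_rng(seed)
    p = rng.dirichlet(np.full(k, float(alpha)), size=n_draws)
    ratio = p.max(axis=1) / p.min(axis=1)
    return float(np.mean(ratio < max_ratio))


# Larger concentrations pull every face towards 1/8, so the ratio constraint
# is satisfied more often.
for alpha in [1, 10, 30, 56, 100]:
    print(f"alpha = {alpha:>3}: {frac_within_ratio(alpha):.3f} of draws have max/min < 2")
```

Picking the smallest concentration for which nearly all draws satisfy the constraint is one reasonable reading of the exercise; it is a judgment call, not a unique answer.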

The way I’ve attempted this is to model each frequency as Poisson-distributed, with the rate given by the probability of that side multiplied by the total number of spins:

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, Predictive

numpyro.set_host_device_count(4)

df = pd.DataFrame(
    {
        "Value": [0, 1, 2, 3, 4, 5, 6, 7],
        "Frequency": [18, 19, 22, np.nan, np.nan, 19, 20, 22],
    }
)


def model(value, frequency):
    # Use the arguments rather than the global `df`. Values are 0-indexed,
    # so the NaNs at indices 3 and 4 are the 4s and 5s.
    idx_missing = np.where(np.isnan(frequency))[0]
    n_missing = len(idx_missing)
    n_vals = len(np.unique(value))

    n_spins = numpyro.sample("n_spins", dist.Normal(160, 20))
    probs = numpyro.sample("probs", dist.Dirichlet(jnp.ones(n_vals) * 56))
    sigma = numpyro.sample("sigma", dist.HalfNormal(5))  # note: sigma is not used below

    rate = numpyro.deterministic("rate", probs * n_spins)
    # .mask(False) drops this prior's log-density, so only the Poisson
    # likelihood constrains the (continuous) imputed counts
    frequency_impute = numpyro.sample(
        "frequency_impute", dist.Normal(20, 5).expand([n_missing]).mask(False)
    )
    # jax.ops.index_update has been removed from newer JAX versions; use .at[].set()
    frequency_imputed = jnp.asarray(frequency).at[idx_missing].set(frequency_impute)
    numpyro.sample("frequency", dist.Poisson(rate[value]), obs=frequency_imputed)


mcmc = MCMC(
    NUTS(model, target_accept_prob=0.9),
    num_chains=4,
    num_samples=1000,
    num_warmup=1000,
)
mcmc.run(
    jax.random.PRNGKey(0),
    value=df["Value"].to_numpy(),
    frequency=df["Frequency"].to_numpy(),
)
mcmc.print_summary()

This “works” in the sense that it gives reasonable-looking imputations for the missing frequencies:


                         mean       std    median      5.0%     95.0%     n_eff     r_hat
frequency_impute[0]     20.09      5.52     19.70     10.74     28.45   2961.27      1.00
frequency_impute[1]     20.15      5.48     19.81     11.22     29.00   3089.66      1.00
            n_spins    160.77     12.07    160.42    141.51    180.78   3246.51      1.00
           probs[0]      0.12      0.01      0.12      0.10      0.14   4263.82      1.00
           probs[1]      0.12      0.01      0.12      0.10      0.15   4613.60      1.00
           probs[2]      0.13      0.01      0.13      0.11      0.15   4068.24      1.00
           probs[3]      0.13      0.02      0.12      0.10      0.15   3381.19      1.00
           probs[4]      0.13      0.02      0.12      0.10      0.15   3462.87      1.00
           probs[5]      0.12      0.01      0.12      0.10      0.14   4129.95      1.00
           probs[6]      0.13      0.01      0.12      0.10      0.15   4117.11      1.00
           probs[7]      0.13      0.01      0.13      0.11      0.15   4036.46      1.00
              sigma      4.05      3.05      3.47      0.00      8.31   4387.28      1.00

Number of divergences: 0

However, I’m bothered that n_spins is modelled with a Normal distribution, even though it can only take non-negative integer values.
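A continuous n_spins may be less of a problem than it looks, because of Poisson thinning: if the total is N ~ Poisson(λ) and the counts given N are Multinomial(N, p), then the individual counts are independent Poisson(λ p_i). In a Poisson model like the one above, the continuous parameter therefore plays the role of the expected number of spins λ, not the realized integer total. The simulation below is only a numerical check of that fact; λ = 160 and a fair spinner are illustrative assumptions.

```python
import numpy as np


def simulate_counts(lam, probs, n_sims, seed=0):
    """Draw N ~ Poisson(lam), then counts | N ~ Multinomial(N, probs), n_sims times."""
    rng = np.random.default_rng(seed)
    totals = rng.poisson(lam, size=n_sims)
    counts = np.stack([rng.multinomial(n, probs) for n in totals])
    return totals, counts


probs = np.full(8, 1 / 8)  # illustrative: a fair spinner
totals, counts = simulate_counts(160.0, probs, n_sims=20_000)

# If thinning holds, each count is Poisson(160 / 8 = 20): mean and variance both
# close to 20, and different faces independent (sample correlation close to 0).
print("means:    ", counts.mean(axis=0).round(2))
print("variances:", counts.var(axis=0).round(2))
print("corr(1s, 2s):", np.corrcoef(counts[:, 0], counts[:, 1])[0, 1].round(3))
```

For contrast, with N held fixed the multinomial counts of two fair faces have correlation -1/7; it is the Poisson prior on the total that removes it.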

Is there a better way to solve this problem?

he’s confident that none of the values is twice as likely as any other

I think we can use this assumption to set truncated priors on the missing frequencies and then marginalize them out. Something like:

def model(...):
    probs = numpyro.sample("probs", dist.Dirichlet(...))
    # I use Categorical here, but it is just a discrete uniform distribution over [0..39];
    # a DiscreteUniform distribution might be a nice feature request,
    # because its sampler would be much faster than the categorical samplers.
    freq_3 = numpyro.sample("freq_3", dist.Categorical(logits=jnp.zeros(40)).mask(False))
    freq_4 = numpyro.sample("freq_4", dist.Categorical(logits=jnp.zeros(40)).mask(False))
    # optionally add a "truncated" Poisson prior to those missing frequencies
    # numpyro.factor("freq_3_factor", dist.Poisson(rate=20).log_prob(freq_3))
    # impute_fn is a placeholder: put freq_3 and freq_4 at indices 3 and 4
    frequency_imputed = impute_fn(frequency, freq_3, freq_4)
    numpyro.sample("obs", dist.Multinomial(frequency_imputed.sum(-1), probs),
                   obs=frequency_imputed)

Running MCMC(NUTS) will automatically marginalize out the missing values and give you the posterior over “probs”. Then you can use infer_discrete to get samples for those missing values. Alternatively, you can run MCMC(DiscreteHMCGibbs(NUTS)) to get posteriors over both “probs” and the missing values (but the problem asks to marginalize out the counts, so I would prefer the former solution).
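For a problem this small, the missing counts can also be marginalized by brute force, which may be handy for checking what NUTS or infer_discrete returns. This sketch is not the NumPyro approach described above: it swaps in Dirichlet-multinomial conjugacy to integrate out the probabilities analytically, then enumerates every pair of missing counts on the same 0..39 grid, with the flat prior of the Categorical(logits=jnp.zeros(40)) sites. The concentration of 56 is borrowed from the question. Note that flat priors on the two missing counts imply a triangular prior on the total number of spins; a different prior on the total could be added as an extra log-weight term.

```python
import math

import numpy as np

OBSERVED = {0: 18, 1: 19, 2: 22, 5: 19, 6: 20, 7: 22}  # 0-indexed value -> count
MISSING = (3, 4)  # the 4s and 5s


def log_dirichlet_multinomial(counts, alpha):
    """log P(counts | alpha), with the Dirichlet(alpha) probabilities integrated out."""
    n, a_sum = sum(counts), sum(alpha)
    out = math.lgamma(n + 1) + math.lgamma(a_sum) - math.lgamma(n + a_sum)
    for x, a in zip(counts, alpha):
        out += math.lgamma(x + a) - math.lgamma(x + 1) - math.lgamma(a)
    return out


def grid_posterior(alpha=56.0, max_count=40):
    """Joint posterior of the two missing counts on {0..max_count-1}^2,
    assuming a flat prior on each missing count."""
    alphas = [alpha] * 8
    logw = np.empty((max_count, max_count))
    for f4 in range(max_count):
        for f5 in range(max_count):
            counts = [OBSERVED.get(i, 0) for i in range(8)]
            counts[MISSING[0]], counts[MISSING[1]] = f4, f5
            logw[f4, f5] = log_dirichlet_multinomial(counts, alphas)
    w = np.exp(logw - logw.max())  # subtract the max for numerical stability
    return w / w.sum()


post = grid_posterior()  # post[i, j] = P(4s = i, 5s = j | observed counts)
f = np.arange(post.shape[0])
print("posterior mean of the 4s:", round(float((post.sum(axis=1) * f).sum()), 2))
print("posterior mode (4s, 5s):", np.unravel_index(post.argmax(), post.shape))

# Marginal posterior of the total: N = (sum of observed counts) + 4s + 5s
n_obs = sum(OBSERVED.values())
n_post = np.bincount((f[:, None] + f[None, :]).ravel(), weights=post.ravel())
mean_total = n_obs + (n_post * np.arange(n_post.size)).sum()
print("posterior mean of total spins:", round(float(mean_total), 1))
```

To plot the joint posterior of 4s and 5s that the exercise asks for, something like matplotlib's `plt.imshow(post, origin="lower")` should work, with rows (the y-axis) indexing the 4s and columns the 5s.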
