Bayesian imputation - what does the masked distribution do?

Here’s an example, adapted from the docs:

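import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import numpy as np
import jax

# 10 data points: the first 4 are missing (NaN), the last 6 are observed
x = np.hstack([np.array([np.nan]*4), np.random.randn(6)*2+5])

def model2b(x):
    # Latent site for the 4 missing values, with a masked prior
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    # Fill the NaN slots with the latent values and keep the observed ones
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    # Likelihood over all 10 entries (imputed + observed)
    numpyro.sample("x", dist.Normal(0, 1), obs=x_imputed)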

mcmc = numpyro.infer.MCMC(
    numpyro.infer.NUTS(model2b),
    num_chains=1,
    num_samples=1000,
    num_warmup=1000,
)
mcmc.run(jax.random.PRNGKey(0), x=x)

mcmc.print_summary()

If I run this, I get:

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
x_impute[0]      0.04      0.97      0.03     -1.52      1.68   1070.07      1.00
x_impute[1]     -0.01      0.94      0.01     -1.66      1.45   1037.47      1.00
x_impute[2]     -0.03      0.98     -0.05     -1.45      1.76   1048.96      1.00
x_impute[3]     -0.01      0.95     -0.02     -1.55      1.54   1052.58      1.00

Number of divergences: 0

If I change the line

    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))

to

    x_impute = numpyro.sample("x_impute", dist.Normal(500, 1200).expand([4]).mask(False))

and re-run, then I still get the same output.

So my question is - what does the distribution specified there do? Does it matter what one puts?

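With mask(False), it does not matter, because the masked distribution's log density is 0 at every point in its support, so it adds nothing to the joint density. The posterior of x_impute is then determined entirely by the likelihood term numpyro.sample("x", dist.Normal(0, 1), obs=x_imputed), which is why you get a mean of roughly 0 and a std of roughly 1 in both cases. The only thing that matters is the support (here the whole real line for both Normals). The parameters do matter if you draw samples from the prior at those sites, or initialize the inference algorithm with the init_to_sample or init_to_median strategies, since those draw from the underlying distribution.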


That’s very clear, thanks @fehiepsi!