Here’s an example, adapted from the docs:
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import numpy as np
import jax
import seaborn as sns
x = np.hstack([np.array([np.nan]*4), np.random.randn(6)*2+5])
def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    numpyro.sample("x", dist.Normal(0, 1), obs=x_imputed)
mcmc = numpyro.infer.MCMC(
    numpyro.infer.NUTS(model2b),
    num_chains=1,
    num_samples=1000,
    num_warmup=1000,
)
mcmc.run(jax.random.PRNGKey(0), x=x)
mcmc.print_summary()
If I run this, I get:
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
x_impute[0]      0.04      0.97      0.03     -1.52      1.68   1070.07      1.00
x_impute[1]     -0.01      0.94      0.01     -1.66      1.45   1037.47      1.00
x_impute[2]     -0.03      0.98     -0.05     -1.45      1.76   1048.96      1.00
x_impute[3]     -0.01      0.95     -0.02     -1.55      1.54   1052.58      1.00

Number of divergences: 0
If I change the line
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
to
    x_impute = numpyro.sample("x_impute", dist.Normal(500, 1200).expand([4]).mask(False))
and re-run, then I still get the same output.
So my question is: what does the distribution specified there actually do? Does it matter which distribution and parameters I put there?
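For what it's worth, here is a minimal sketch of what I suspect is going on, written in plain Python rather than NumPyro. It assumes (from my reading of the docs, not something I have confirmed in the source) that a site with .mask(False) contributes zero to the log joint density, so only the "x" site with Normal(0, 1) scores the imputed values. The function log_joint and the sample values are hypothetical stand-ins for illustration, not NumPyro internals.

```python
import math


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_joint(x_impute, x_observed, prior_loc, prior_scale, mask):
    """Hypothetical stand-in for the log joint density of model2b.

    Assumption: when mask is False, the "x_impute" site contributes 0,
    so prior_loc and prior_scale never enter the result.
    """
    prior_term = sum(normal_logpdf(v, prior_loc, prior_scale) for v in x_impute)
    if not mask:
        prior_term = 0.0
    # The "x" site scores every entry, imputed and observed, under Normal(0, 1).
    all_values = list(x_impute) + list(x_observed)
    likelihood_term = sum(normal_logpdf(v, 0.0, 1.0) for v in all_values)
    return prior_term + likelihood_term


# Made-up values, only for illustration.
x_impute = [0.3, -1.2, 0.8, 0.1]
x_observed = [4.1, 6.3, 5.2, 3.9, 7.0, 5.5]

masked_a = log_joint(x_impute, x_observed, 0.0, 1.0, mask=False)
masked_b = log_joint(x_impute, x_observed, 500.0, 1200.0, mask=False)
unmasked_a = log_joint(x_impute, x_observed, 0.0, 1.0, mask=True)
unmasked_b = log_joint(x_impute, x_observed, 500.0, 1200.0, mask=True)

print(masked_a == masked_b)      # masked: prior parameters drop out
print(unmasked_a == unmasked_b)  # unmasked: prior parameters matter
```

If that assumption holds, it would explain why the summary for x_impute looks like draws from Normal(0, 1) in both runs: the only density acting on those values comes from the "x" site. It would still leave open whether the masked distribution matters in other ways, for example through its shape or support.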