I first noticed the issue while inspecting the output printed by .run(). The following minimal example reproduces it.
from jax import random
from numpyro import distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS
import numpy as np
import jax.numpy as jnp

rng_key, rng_key_predict = random.split(random.PRNGKey(1))
n = 10
x = np.random.normal(0, 1, n)

# Model with obs_mask (the mask is all True because x has no NaNs)
def model_mask(x):
    mu = sample("mu", dist.Normal(0, 1))
    sigma = sample("sigma", dist.HalfNormal(1))
    sample("x", dist.Normal(mu, sigma), obs=x, obs_mask=~jnp.isnan(x))

# Same model without obs_mask
def model_nomask(x):
    mu = sample("mu", dist.Normal(0, 1))
    sigma = sample("sigma", dist.HalfNormal(1))
    sample("x", dist.Normal(mu, sigma), obs=x)

kernel_mask = NUTS(model_mask)
mcmc_mask = MCMC(kernel_mask,
                 num_warmup=100,
                 num_samples=100,
                 num_chains=1,
                 progress_bar=True)

kernel_nomask = NUTS(model_nomask)
mcmc_nomask = MCMC(kernel_nomask,
                   num_warmup=100,
                   num_samples=100,
                   num_chains=1,
                   progress_bar=True)

mcmc_mask.run(rng_key, x,
              extra_fields=("z", "z_grad", "potential_energy", "energy", "r",
                            "trajectory_length", "num_steps", "adapt_state"))

mcmc_nomask.run(rng_key, x,
                extra_fields=("z", "z_grad", "potential_energy", "energy", "r",
                              "trajectory_length", "num_steps", "adapt_state"))
The models model_mask and model_nomask differ only in the obs_mask argument passed to sample. The data used for estimation is the same in both cases and contains no missing values.
Thus, I expected mcmc_mask and mcmc_nomask to show very similar NUTS sampler behavior and hence very similar results. However, the 'num_steps' field from mcmc_mask.get_extra_fields() is 1023 for every draw, i.e. the maximum tree depth was hit on every iteration. The same field from mcmc_nomask.get_extra_fields() is always 7 or less.
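As a sanity check on these numbers, here is a small sketch of how I read them. It assumes that num_steps counts leapfrog steps in a full binary tree, so a tree of depth d gives 2^d - 1 steps, and that the NUTS max_tree_depth is left at 10, which I believe is the default. The helper name below is my own, not part of NumPyro.

```python
def max_leapfrog_steps(max_tree_depth: int) -> int:
    """Steps in a full binary tree of the given depth (assumed: 2**d - 1)."""
    return 2 ** max_tree_depth - 1


# Depth 10 (assumed default max_tree_depth) matches the masked model.
steps_at_depth_10 = max_leapfrog_steps(10)

# Depth 3 matches the upper bound seen for the unmasked model.
steps_at_depth_3 = max_leapfrog_steps(3)

print(steps_at_depth_10, steps_at_depth_3)
```

Under this reading, the masked model saturates at depth 10 on every draw, while the unmasked model never needs more than depth 3.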
My own inspection of the code turned up the following. The momentum vector r (as in Hoffman & Gelman 2014) should have length equal to the number of latent random variables, here two: mu and sigma. Inspecting r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum) in hmc, I got the following values of r
for mcmc_mask:
{'mu': DeviceArray(-2.8184042, dtype=float32), 'sigma': DeviceArray(0.8608751, dtype=float32), 'x_unobserved': DeviceArray([ 0.55401844,  0.063475  , -0.00875445, -1.1976603 ,
  1.601614  ,  1.2322016 , -0.5856139 ,  0.4198636 ,
 -0.31236285,  0.08368913], dtype=float32)}
for mcmc_nomask:
{'mu': DeviceArray(0.74702597, dtype=float32), 'sigma': DeviceArray(-0.6659205, dtype=float32)}
My reading is that simply passing obs_mask, even when nothing is actually masked, adds 10 new latent variables (one per data point) to the estimation under the name x_unobserved. Including some genuinely missing data leads to the same behavior. This affects the kinetic energy calculation, since the momentum vector now has 12 components instead of 2.
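To make the kinetic energy point concrete, here is a rough sketch using the momentum values printed above. It assumes a unit diagonal mass matrix, so K = 0.5 * sum(r_i^2); the adapted mass matrix in the real run will differ, and the helper names are mine rather than NumPyro's.

```python
def flatten(momentum):
    """Flatten a dict of scalars and lists into one list of floats."""
    flat = []
    for value in momentum.values():
        if isinstance(value, (list, tuple)):
            flat.extend(value)
        else:
            flat.append(value)
    return flat


def kinetic_energy(momentum):
    """K = 0.5 * sum(r_i^2), assuming a unit diagonal mass matrix."""
    return 0.5 * sum(r * r for r in flatten(momentum))


# Momentum values copied from the output above.
r_mask = {
    "mu": -2.8184042,
    "sigma": 0.8608751,
    "x_unobserved": [0.55401844, 0.063475, -0.00875445, -1.1976603,
                     1.601614, 1.2322016, -0.5856139, 0.4198636,
                     -0.31236285, 0.08368913],
}
r_nomask = {"mu": 0.74702597, "sigma": -0.6659205}

print(len(flatten(r_mask)), kinetic_energy(r_mask))
print(len(flatten(r_nomask)), kinetic_energy(r_nomask))
```

The ten x_unobserved components contribute to the kinetic energy in the masked case even though no data point is actually missing.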
Can you please help clarify this issue?