Please help me understand the behavior of the mask handler. I observe that observations which are masked out but invalid for the distribution (e.g. 0.0 for a Beta distribution) seem to make NUTS inference fail (every transition diverges). Is this a bug, or am I interpreting masks incorrectly?
Minimal reproducible example:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as np
def mask_invalid(values, mask, fill_value=0.5):
    """Return *values* with every entry where *mask* is False replaced by *fill_value*.

    ``numpyro.handlers.mask`` does not skip evaluation of masked elements: the
    distribution's ``log_prob`` is computed for every element and only then is
    the masked part zeroed out. If a masked element sits where the log density
    (or its gradient) is not finite -- e.g. 0.0 under a Beta -- the gradient of
    the masked log density becomes NaN (0 * inf), and NUTS diverges. Swapping
    such entries for a harmless in-support value before the ``sample``
    statement leaves the masked log density unchanged but keeps gradients
    finite.

    Args:
        values: observed array, or ``None`` when the site is not observed.
        mask: boolean array broadcastable against *values*; ``True`` keeps the
            original entry, ``False`` replaces it.
        fill_value: placeholder for masked-out entries. It must have a finite
            log density and gradient under the site's distribution; the default
            0.5 suits distributions on the unit interval such as Beta.

    Returns:
        ``None`` if *values* is ``None``, otherwise a new array.
    """
    if values is None:
        # Unobserved site: nothing to sanitize, let numpyro sample it.
        return None
    return np.where(mask, values, fill_value)


def model(obs=None):
    """Three Beta observations sharing a Gamma-distributed concentration.

    Both Beta parameters equal ``conc``. The first observation is masked out,
    so its value should not influence inference.

    Args:
        obs: optional array of 3 observations; ``None`` samples ``y`` instead.

    Returns:
        The value at site ``"y"``. When *obs* is given, this is *obs* with the
        masked-out entry replaced by a placeholder.
    """
    conc = numpyro.sample('conc', dist.Gamma(1, 1))
    a = conc * np.ones(3,)
    mask = np.arange(3) > 0  # first observation should be ignored
    # The mask handler still evaluates Beta.log_prob on masked entries, so an
    # out-of-range masked value (e.g. 0.0) would poison the gradient with NaN.
    # Replace masked entries with an in-support placeholder first.
    safe_obs = mask_invalid(obs, mask)
    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample("y", dist.Beta(a, a), obs=safe_obs)
    return y
# Run NUTS on three observation sets that differ only in the masked-out first
# entry, printing a posterior summary of `conc` for each.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10, num_chains=1)

# Expected behavior:
# -- obs1, obs2, obs3 are equivalent, since they differ only in the first
#    entry, which model() masks out.
obs1 = np.array([0.1, 0.5, 0.5])
obs2 = np.array([0.5, 0.5, 0.5])
obs3 = np.array([0.0, 0.5, 0.5])  # invalid: 0.0 is a degenerate point for Beta

init = {'conc': 1.}  # fails to find initial parameters otherwise

for obs in [obs1, obs2, obs3]:
    # The same PRNG key is reused so the three runs are directly comparable.
    mcmc.run(jax.random.PRNGKey(2), obs=obs, init_params=init)
    mcmc.print_summary()
    print(mcmc.get_samples()['conc'])

# Observed behavior (when the masked entry is passed to Beta unchanged):
# -- obs1, obs2 are equivalent
# -- inference fails for obs3 (always diverges), presumably because
#    Beta.log_prob is still evaluated at the masked 0.0 and its non-finite
#    gradient turns into NaN once multiplied by the mask's zero cotangent.