Behavior of mask handler with invalid observation --- possible bug?

Please help me understand the behavior of the mask handler. I observe that masked values that are invalid for the distribution (e.g. equal to 0.0 for the Beta distribution) seem to make NUTS inference fail (always diverge). Is this a bug or am I interpreting masks incorrectly?

Working example here:

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import jax
import jax.numpy as np

def model(obs=None):
    conc = numpyro.sample('conc', dist.Gamma(1, 1))
    a = conc * np.ones(3,) 
    mask = np.arange(3) > 0 # first observation should be ignored
    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample("y", dist.Beta(a, a), obs=obs)
    return y

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10, num_chains=1)

# Expected behavior: 
#  -- obs1, obs2, obs3 are equivalent
obs1 = np.array([0.1, 0.5, 0.5])
obs2 = np.array([0.5, 0.5, 0.5])
obs3 = np.array([0.0, 0.5, 0.5]) # invalid

init = {'conc': 1.}  # fails to find initial parameters otherwise

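for obs in [obs1, obs2, obs3]:
    # the call on this line was truncated in the post; the PRNG key value is an assumption
    mcmc.run(jax.random.PRNGKey(0), obs=obs, init_params=init)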

# Observed behavior
#  -- obs1, obs2 are equivalent
#  -- inference fails for obs3 (always diverges)

I think this is a case of 0 (mask) times inf (log_prob) giving nan, or something like that. We could fix the issue in NumPyro by using np.where(mask, log_prob, 0.) to compute the masked log probability, but I am not sure whether we should support models with invalid observations.

@neerajprad do you have a thought?

I debugged a bit and was under the impression it was already being handled (seemingly correctly) by np.where. I added a print statement to this code path in numpyro/infer/ starting on Line 83

            if mask is not None:
                if scale is not None:
                    log_prob = np.where(mask, scale * log_prob, 0.)
                else:
                    log_prob = np.where(mask, log_prob, 0.)
                print("log prob:", log_prob)

And in all three cases it computed the same log_prob (the first value was set to zero), which seemed to be the correct behavior:

log prob: Traced<ConcreteArray([0. 0.57504654 0.57504654])>with<JVPTrace(level=1/0)>

I added additional diagnostics in init_fn within velocity_verlet in

    def init_fn(z, r):
        """
        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :return: initial state for the integrator.
        """
        potential_energy, z_grad = value_and_grad(potential_fn)(z)
        print("potential_energy: ", potential_energy)
        print("z_grad: ", z_grad)
        return IntegratorState(z, r, potential_energy, z_grad)

In all cases it got the same potential energy:

potential_energy: 0.56818914

In the first two cases (valid observation) it got the same gradient:

z_grad: {'conc': DeviceArray(0.6277808, dtype=float32)}

In the third case (invalid observation) it got nan for the gradient:

z_grad: {'conc': DeviceArray(nan, dtype=float32)}

I don’t understand how these code paths are reached, though, so couldn’t make further progress.

Whoa, thanks for debugging!! I think I can understand the issue now. It seems that we should do something like:

np.where(mask, compute_log_prob_here(), 0.)

instead of

log_prob = compute_log_prob()
np.where(mask, log_prob, 0.)

This is a known issue in jax, and I think it appears in tensorflow and pytorch too. I’ll make a fix for it.
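
For anyone following along, here is a minimal sketch of the effect in plain JAX, with no NumPyro involved. The function masked_total and the toy term (c - 1) * log(x) are made up for illustration and only stand in for a log density that depends on a parameter; jnp is the same jax.numpy module that the thread imports as np. The masked value comes out finite, but the reverse-mode gradient multiplies a zero cotangent by log(0) = -inf and ends up as nan, which seems to match the potential_energy / z_grad printout above.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, 0.5, 0.5])           # first entry is invalid for log
mask = jnp.array([False, True, True])    # first entry should be ignored


def masked_total(c):
    # Toy stand-in for a log density term that depends on a parameter c.
    log_term = (c - 1.0) * jnp.log(x)     # -inf at the masked position
    # Masking after the fact hides the -inf in the value ...
    return jnp.sum(jnp.where(mask, log_term, 0.0))


# ... but not in the gradient with respect to c.
value, grad = jax.value_and_grad(masked_total)(2.0)
print("value:", value)
print("grad: ", grad)
```

The value only sums the two valid entries, while the gradient picks up 0 * (-inf) from the masked one.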

Excellent! Thanks for pointing this out. This also means there is an easy workaround for now.

On the user side I can use np.where to replace bad values, and then also use the mask handler to ignore the effect of the replaced values on the log density and its gradient. E.g., if I do this in my original code:

    if obs is not None:
        obs = np.where((obs > 0) & (obs < 1), obs, 0.5)
    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample("y", dist.Beta(a, a), obs=obs)

then NUTS behaves as expected and returns identical results for all three sets of observations.

@sheldon Please ignore my last comment, I found no easy way to remedy the issue. :frowning: It seems that we should provide valid observations, like your last comment.

FYI, with scalar observation, we can use

log_prob = jax.lax.cond(mask,
                        value, lambda x: site['fn'].log_prob(x),
                        value, lambda x: 0.)

but for array observations, we have to vectorize lax.cond (because it only accepts a scalar condition). This is a problem because mask, value, and site['fn'].log_prob(value) can have different shapes (though the shapes of mask and site['fn'].log_prob(value) are broadcastable). We would have to use information from site['fn'] to handle the batch shape/event shape of value, broadcast mask to the corresponding batch shape, reshape them to 1D batch arrays to apply jax.vmap(jax.lax.cond(...)), then reshape the result. That is a lot of overhead for a simple mask. :frowning:
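
A small sketch of the scalar case, in case it helps. It uses the argument order from more recent JAX releases, lax.cond(pred, true_fun, false_fun, *operands), rather than the older five-argument form in the snippet above. The function masked_log_term and the toy term (c - 1) * log(x) are hypothetical stand-ins, not NumPyro code. As far as I understand, with a scalar predicate outside of vmap only the selected branch runs, so the invalid value never reaches the gradient.

```python
import jax
import jax.numpy as jnp


def masked_log_term(c, x, keep):
    # Hypothetical scalar log density term; `keep` plays the role of the mask.
    return jax.lax.cond(
        keep,
        lambda c: (c - 1.0) * jnp.log(x),   # used when the observation is kept
        lambda c: jnp.zeros_like(c),        # used when the observation is masked
        c,
    )


c = jnp.float32(2.0)

# Masked, invalid observation: the log branch is not taken.
grad_masked = jax.grad(masked_log_term)(c, jnp.float32(0.0), False)

# Kept, valid observation: the gradient is log(0.5).
grad_kept = jax.grad(masked_log_term)(c, jnp.float32(0.5), True)

print(grad_masked, grad_kept)
```

This only covers a scalar predicate; it does not address the shape handling needed for array observations.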

Hi @fehiepsi. Thanks for looking into it! Having to substitute a default value prior to evaluating the function seems consistent with the workarounds in jax and tensorflow (found from the link you previously posted):

def my_log_or_y(x, y):
    return np.where(x > 0., np.log(np.where(x > 0., x, 1.)), y)

tf.where(x_ok, f(tf.where(x_ok, x, safe_x)), safe_f(x))

But wouldn’t it be possible for numpyro to substitute a default value prior to computing the log density? I guess the trick is you would need a way to find a valid default value for each distribution — I don’t know the details enough to know whether this is possible.
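
To make the idea concrete, here is a rough sketch of what such a substitution could look like. The helper safe_masked_log_prob and its safe_value argument are hypothetical and not part of NumPyro; the caller has to supply a value that is valid for the distribution, which is exactly the part I don't know how to automate. The Beta log density comes from jax.scipy.stats.beta.logpdf so the snippet runs with JAX alone.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta


def safe_masked_log_prob(log_prob_fn, value, mask, safe_value):
    # Hypothetical helper: swap masked entries for a valid value *before*
    # evaluating the log density, then zero out their contribution.
    safe = jnp.where(mask, value, safe_value)
    return jnp.where(mask, log_prob_fn(safe), 0.0)


mask = jnp.arange(3) > 0  # first observation should be ignored


def total_log_prob(conc, obs):
    log_prob = safe_masked_log_prob(
        lambda v: beta.logpdf(v, conc, conc), obs, mask, 0.5
    )
    return jnp.sum(log_prob)


grad_valid = jax.grad(total_log_prob)(1.5, jnp.array([0.5, 0.5, 0.5]))
grad_invalid = jax.grad(total_log_prob)(1.5, jnp.array([0.0, 0.5, 0.5]))
print(grad_valid, grad_invalid)
```

With the substitution in place, the masked entry is evaluated at 0.5 in both cases, so the two gradients should agree and stay finite.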

Anyway, I can easily do the workaround on the user side for now. I just think that the mask handler is a nice way to handle missing values, and it's non-obvious to the user that they would need to both use a mask and replace the invalid values.

Thanks again for your help.