Poutine nan mask not working

Hi. I could have sworn I’ve previously gotten poutine masking to ignore NaN observations, but I can’t get it working in this super simple toy example. The parameters for the first 9 parent instances are all learned correctly, but the last is learned as NaN. My intention was for the single NaN child observation to be masked out, so that the last parent instance is learned using only the other 19 non-NaN child observations.

Am I missing something obvious? Did something change/break in a recent release?

MWE
import torch
import pyro
import pyro.distributions as dist

pyro.enable_validation(False)  # allow NaN observations to pass without validation errors

num_parent = 10
num_child = 20


def model(obs):
    with pyro.plate('parent_plate', size=num_parent, dim=-1):
        parent = pyro.sample('parent', dist.Normal(0.5, 0.1))

        with pyro.plate('child_plate', size=num_child, dim=-2):

            with pyro.poutine.mask(mask=~torch.isnan(obs)):
                pyro.sample('child', dist.Normal(parent, 0.01), obs=obs)


def guide(obs):
    with pyro.plate('parent_plate', size=num_parent, dim=-1):
        mean = pyro.param('parent_mean', torch.full((num_parent,), .5))
        pyro.sample('parent', dist.Normal(mean, 0.1))


def main():
    # clone() the expanded view: expand() returns a stride-0 view, so writing
    # one element through it would otherwise set the whole last column to NaN.
    obs = torch.linspace(0, 1, num_parent).expand((num_child, num_parent)).clone()
    obs[-1, -1] = float('nan')

    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adam({}),
        loss=pyro.infer.Trace_ELBO(),
    )

    for _ in range(1000):
        svi.step(obs)

    param_store = pyro.get_param_store()

    mean = param_store['parent_mean']
    parent_pred = [round(float(i), 2) for i in mean]

    print('parent prediction:', parent_pred)


if __name__ == '__main__':
    main()

Due to a longstanding issue / bug in PyTorch, NaN gradients can propagate into non-NaN gradients even when masked out. The workaround in Pyro is to combine poutine.mask (as you have done) with replacing the NaN obs values with some other value that is non-NaN (e.g. zero).
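For reference, a minimal sketch of that workaround applied to the model above, filling the NaN entries with zero (any finite value works under a Normal likelihood) while the mask removes their contribution to the ELBO:

def model(obs):
    nan_mask = torch.isnan(obs)
    # Fill NaNs with an arbitrary non-NaN value so that log_prob (and its
    # gradient) is finite everywhere; the mask zeroes out these entries'
    # contribution to the ELBO.
    obs_filled = torch.nan_to_num(obs, nan=0.0)

    with pyro.plate('parent_plate', size=num_parent, dim=-1):
        parent = pyro.sample('parent', dist.Normal(0.5, 0.1))

        with pyro.plate('child_plate', size=num_child, dim=-2):

            with pyro.poutine.mask(mask=~nan_mask):
                pyro.sample('child', dist.Normal(parent, 0.01), obs=obs_filled)

The mask alone isn’t enough because the backward pass multiplies the masked-out entries’ (NaN) local gradients by zero, and 0 * NaN = NaN.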


Ahh, I had that fix in my real code from our previous discussions, but forgot to include it in the toy example. The extra issue in my real code, though, was that I was using an intentionally invalid dummy value, e.g. -12345 for a (0,1)-valued observation, so that was still creating NaN gradients. I’m not sure how my code was working before with that invalid dummy value, if it even was, but regardless, replacing it with a valid (in-support) dummy value works. Thanks again @fritzo!
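In case it helps anyone else: a quick illustration of why the dummy value has to be in-support, using a Beta likelihood as a hypothetical stand-in for my (0,1)-valued observation:

import torch
import pyro.distributions as dist

# Out-of-support dummy: log_prob is already NaN, so the mask can't stop
# the NaN from leaking into the gradients.
print(dist.Beta(2.0, 2.0, validate_args=False).log_prob(torch.tensor(-12345.0)))
# tensor(nan)

# In-support dummy: log_prob is finite, and the masked entry simply
# contributes nothing to the ELBO.
print(dist.Beta(2.0, 2.0, validate_args=False).log_prob(torch.tensor(0.5)))
# tensor(0.4055)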
