Parent nodes with missing data

Following up on @fritzo’s responses here and here on how to handle missing (MCAR) data. Sequential observations would be straightforward but would hurt runtime too much. poutine.mask is great, and I’ve been able to get it working to filter out missing values from leaf nodes, as in the linked examples.

What I’m currently stuck on, however, is how to deal with partially-observed nodes that have children. I can replace the missing data with valid dummy values (as in the second link) and use poutine.mask to exclude those dummy values from the ELBO calculation. But then any children of the masked node will still be conditioned on those dummy values, which would throw off training.

Is there something I’m missing in terms of partially-observing parent nodes?

One thing we’ve tried is to explicitly model the masking, as below. The original partially-observed variable is replaced by a replica called latent_var, which is unobserved. A new observed binary missing_var takes the masking data. A new observed_var is an observed child of those two, takes the observed data originally intended for the original variable, and is a pyro.deterministic node with the element-wise function latent_var if not missing_var else valid_dummy_value.

 (original parents)          (original parents)
        |                           |
        v                           v
 (original variable)  ===>     (latent_var) --> (observed_var) <-- (missing_var)
        |                           |
        v                           v
 (original children)         (original children)

One issue off the bat is that during training, gradients can’t flow through the Delta distribution of the deterministic node. Replacing it with a very tight Normal distribution should give a decent approximation. Even so, this approach doesn’t give useful results at prediction time: it always predicts the same outcome for latent_var across all data instances.

Hi @gbernstein,
great question! I think what you’re describing is a feature we’ve discussed under the name masked sample statement, which is not yet implemented. When I’ve needed this feature in models, I usually hand-implement the masking using poutine.mask, two sample sites, and torch.where. For example, starting with a non-masked model:

```python
import torch
import pyro
import pyro.distributions as dist

child_data = torch.tensor(..., dtype=torch.float)

def basic_model(child_data):
    size = len(child_data)  # number of data points, not number of dims
    with pyro.plate("plate", size):
        parent_dist = dist.Normal(0, 1).expand([size])
        parent = pyro.sample("parent", parent_dist)
        pyro.sample("child", dist.Normal(parent, 1),
                    obs=child_data)
```

we can transform it to mask the parent:

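```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

parent_data = torch.tensor(..., dtype=torch.float)  # valid dummy values where missing
parent_mask = torch.tensor(..., dtype=torch.bool)   # True where parent is observed

def masked_model(child_data, parent_data, parent_mask):
    size = len(child_data)
    with pyro.plate("plate", size):
        parent_dist = dist.Normal(0, 1).expand([size])
        # observed copy: only counts where the parent was actually observed
        with poutine.mask(mask=parent_mask):
            parent_1 = pyro.sample("parent_1", parent_dist, obs=parent_data)
        # latent copy: only counts where the parent is missing
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist)  # no obs
        # the child sees observed values where available, latent draws elsewhere
        parent = pyro.deterministic(
            "parent",
            torch.where(parent_mask, parent_1, parent_0),
        )
        pyro.sample("child", dist.Normal(parent, 1), obs=child_data)
```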

Super helpful with the before and after, thanks a ton @fritzo! And I vote for that masked sample statement feature. I’m surprised more people haven’t run into that kind of issue.

A follow-up warning for future visitors to this post: this solution needs a very small tweak, mentioned in Fritzo’s previous posts (linked above in my first post), to work when the parent distribution is changed to Categorical. Before sampling the observed parent data in the model, you must replace the nan values with valid dummy values.

Nitty-gritty details: when calculating the loss, the log-probability is computed for every observation, and only the entries left unmasked by poutine.mask are then kept. This is fine for a Normal variable because the return statement of torch.distributions.normal.Normal.log_prob handles nans gracefully (it simply propagates them). The return statement of torch.distributions.categorical.Categorical.log_prob, however, cannot, because it uses the value as an index in torch.gather. The presence of nans in the data will indeed raise an exception if Pyro validation is enabled.
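A small torch-only sketch of this behaviour, with made-up values. With validation switched off, Normal.log_prob just returns NaN for a NaN input, and selecting entries with torch.where (which, as far as I understand, is roughly how Pyro applies a mask) discards it. A Categorical with validation on rejects the NaN outright.

```python
import torch
from torch.distributions import Categorical, Normal

value = torch.tensor([1.0, 0.0, float("nan")])  # last entry is missing
mask = ~torch.isnan(value)

# Normal.log_prob is plain arithmetic, so a NaN input yields a NaN entry...
normal_lp = Normal(0.0, 1.0, validate_args=False).log_prob(value)
# ...which torch.where drops (multiplying by the mask would keep NaN * 0 = NaN)
masked_normal = torch.where(mask, normal_lp, torch.zeros_like(normal_lp)).sum()

# Categorical.log_prob needs an integer index for gather; with validation on,
# the NaN fails the support check before any gather happens
cat = Categorical(probs=torch.tensor([0.2, 0.8]), validate_args=True)
try:
    cat.log_prob(value)
    raised = False
except ValueError:
    raised = True

print(normal_lp, masked_normal, raised)
```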

Since the corresponding log-probability entry will be masked out anyway, the solution is to replace the nan with any valid dummy category to make gather happy. We can confirm that the choice of dummy value doesn’t affect anything by fixing the random seed and printing the loss with nans replaced by different valid dummy values.

MWE
```python
import torch
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import numpy as np

plate_size = 10
prior = torch.tensor([.2, .8])

def basic_model(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)
        pyro.sample("child", dist.Normal(parent.float(), 1),
                    obs=child_data)


def basic_guide(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)


def masked_model(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        with poutine.mask(mask=parent_mask):
            parent_1 = pyro.sample("parent_1", parent_dist, obs=parent_data).float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))
        pyro.sample("child", dist.Normal(parent, 1), obs=child_data)


def masked_guide(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent_1 = parent_data.float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))


child_data = torch.ones(plate_size, dtype=torch.float)

svi = SVI(basic_model, basic_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data))

parent_data = torch.ones(plate_size, dtype=torch.float)
parent_data[0] = np.nan
parent_mask = ~torch.isnan(parent_data)

# replace nans with valid dummy value now that we have the mask
parent_data[~parent_mask] = 0

svi = SVI(masked_model, masked_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data, parent_data, parent_mask))
```


It looks like the above fix of replacing to-be-masked nan values is now needed for non-Categorical distributions as well. As far as I can tell, this is because when computing log_prob, the distribution first checks that all observed values are within its support (when validation is enabled), regardless of any poutine.mask…

MWE
```python
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
import torch

# obs = torch.tensor((0., 0., float('nan')))  # <--- yields error (shown below)
obs = torch.tensor((0., 0., 20.))              # <--- no error; 20 is ignored in the loss because of the mask
mask = torch.tensor((True, True, False))

def model():
    with pyro.plate('plate', size=len(obs)):
        with pyro.poutine.mask(mask=mask):
            return pyro.sample('node', dist.Normal(0, 1), obs=obs)


def guide():
    pass


print(Trace_ELBO().loss(model, guide))  # <--- prints 1.83...
```
Error trace
```
ValueError: Error while computing log_prob at site 'node':
The value argument must be within the support
Trace Shapes:
 Param Sites:
Sample Sites:
    node dist 3 |
        value 3 |
```

I believe you can just disable validation; see pyro.enable_validation.

Ah, that makes sense too. Though that seems like too broad a brush stroke for most cases. I could maybe disable validation in production, but I’d still want it in places like tests.

I just wish the poutine mask were propagated further, to things like log_prob calculations.

Yeah, I’m not sure what the best solution would be. Possibilities include:

  • ignore validation when mask is in play
  • expand validation options to different levels of strictness, one of which corresponds to the above
  • others? @fritzo @eb8680_2 @fehiepsi thoughts?

In NumPyro, we implemented feasible_like for constraints to deal with this issue. It would make a nice feature request for PyTorch distributions too.

Note you can disable validation at just a single sample site via e.g.

```python
dist.Normal(loc, scale, validate_args=False)
```

I suppose in the long run a clean solution would be to add a constraints.masked(). Here, I’ll open a Pyro issue.

Good to know, thanks.

Related issue: the ELBO loss returned by pyro.infer.SVI.step() also comes out nan, even when all the nan sites are masked and the learned parameters come out fine.

```
/…/lib/python3.7/site-packages/pyro/infer/trace_elbo.py:143: UserWarning: Encountered NaN: loss
  warn_if_nan(loss, "loss")
```