Parent nodes with missing data

Following up on @fritzo’s responses here and here on how to handle missing (MCAR) data. Handling observations sequentially would be straightforward but would hurt runtime too much. poutine.mask is great, and I’ve been able to get it to work filtering missing values out of leaf nodes, as in the linked examples.

What I’m currently stuck on, however, is how to deal with partially-observed nodes that have children. I can replace the missing data with valid dummy values (as in the second link) and use poutine.mask to exclude those dummy values from the ELBO calculation. But then any children of the masked node will use those dummy values, which would throw off training.

Is there something I’m missing in terms of partially-observing parent nodes?

One thing we’ve tried is to explicitly model the masking, as below. The original partially-observed variable is replaced by an unobserved replica called latent_var. A new observed binary missing_var takes the masking data. A new observed_var is an observed child of those two: it takes the observed data originally intended for the original variable, and is a pyro.deterministic node computing the element-wise function latent_var if not missing_var else valid_dummy_value.

 (original parents)          (original parents)
        |                           |
        v                           v
 (original variable)  ===>     (latent_var) --> (observed_var) <-- (missing_var)
        |                           |
        v                           v
 (original children)         (original children)

One issue off the bat is that during training, gradients can’t flow through the Delta distribution of the deterministic node. Replacing it with a very tight Normal distribution should give a decent approximation. Even so, this approach doesn’t give useful results at prediction time: it always predicts a single outcome for latent_var across all data instances.

Hi @gbernstein,
great question! I think what you’re describing is a feature we’ve discussed under the name masked sample statement, which is not yet implemented. When I’ve needed this feature in models, I usually hand-implement the masking using poutine.mask, two sample sites, and torch.where. For example, starting with a non-masked model

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

child_data = torch.tensor(..., dtype=torch.float)

def basic_model(child_data):
    size = len(child_data)
    with pyro.plate("plate", size):
        parent_dist = dist.Normal(0, 1).expand([size])
        parent = pyro.sample("parent", parent_dist)
        pyro.sample("child", dist.Normal(parent, 1),
                    obs=child_data)

we can transform it to mask the parent

parent_data = torch.tensor(..., dtype=torch.float)
parent_mask = torch.tensor(..., dtype=torch.bool)

def masked_model(child_data, parent_data, parent_mask):
    size = len(child_data)
    with pyro.plate("plate", size):
        parent_dist = dist.Normal(0, 1).expand([size])
        with poutine.mask(mask=parent_mask):
            parent_1 = pyro.sample("parent_1", parent_dist, obs=parent_data)
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist)  # no obs
        parent = pyro.deterministic(
            "parent",
            torch.where(parent_mask, parent_1, parent_0),
        )
        pyro.sample("child", dist.Normal(parent, 1), obs=child_data)

Super helpful with the before and after, thanks a ton @fritzo! And I vote for that masked sample statement feature. I’m surprised more people haven’t run into that kind of issue.

A followup warning for future visitors to this post. This solution requires a very small tweak, mentioned in fritzo’s previous posts (linked above in my first post), to get it to work when changing the parent distribution to Categorical: before sampling the observed parent data in the model, you must replace the nan values.

Nitty-gritty details: when calculating the loss, the log-probability is computed for each observation, and only the log-probabilities kept by poutine.mask contribute. This is fine for a Normal variable because torch.distributions.Normal.log_prob handles nans gracefully (the arithmetic simply propagates them). torch.distributions.Categorical.log_prob, however, cannot handle nans, because it tries to use the value as an index in torch.gather. The presence of nans in the data will also raise an exception if Pyro validation is enabled.
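A small torch-only illustration of the difference (the explicit validate_args settings are mine, to make each behavior deterministic):

```python
import math
import torch
import torch.distributions as tdist

nan = torch.tensor(float("nan"))

# Normal.log_prob is pure arithmetic, so a nan input just yields a nan
# log-probability, which a mask can later drop.
normal = tdist.Normal(0.0, 1.0, validate_args=False)
assert math.isnan(normal.log_prob(nan).item())

# Categorical.log_prob uses the value as a gather index; with validation
# enabled, a nan is rejected outright as outside the support.
categorical = tdist.Categorical(torch.tensor([0.2, 0.8]), validate_args=True)
try:
    categorical.log_prob(nan)
except ValueError:
    print("nan rejected by Categorical.log_prob")
```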

Since the corresponding log-probability entry will be masked out anyways, the solution is to replace the nan with any valid dummy category to make gather happy. We can confirm the choice of dummy value doesn’t affect anything by setting the random seed and printing the loss when replacing nans with different valid dummy values.

MWE
import torch
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import numpy as np

plate_size = 10
prior = torch.tensor([.2, .8])

def basic_model(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)
        pyro.sample("child", dist.Normal(parent.float(), 1),
                    obs=child_data)


def basic_guide(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)


def masked_model(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        with poutine.mask(mask=parent_mask):
            parent_1 = pyro.sample("parent_1", parent_dist, obs=parent_data).float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))
        pyro.sample("child", dist.Normal(parent, 1), obs=child_data)


def masked_guide(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent_1 = parent_data.float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))


child_data = torch.ones(plate_size, dtype=torch.float)

svi = SVI(basic_model, basic_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data))

parent_data = torch.ones(plate_size, dtype=torch.float)
parent_data[0] = np.nan
parent_mask = ~torch.isnan(parent_data)

# replace nans with valid dummy value now that we have the mask
parent_data[~parent_mask] = 0

svi = SVI(masked_model, masked_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data, parent_data, parent_mask))
