Masked observations and missing sites in guide

When some observations are masked with obs_mask, a new site with the _unobserved tag is added to the model. This usually only triggers a warning, but sometimes it also causes Pyro to crash with no obvious explanation.
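
To make this concrete, here is a tiny model of my own (the model and site names are mine, not the tutorial's) that shows the extra site appearing in a trace:

import torch
import pyro
import pyro.distributions as dist

def tiny_model(data, mask):
    with pyro.plate("data", len(data)):
        # obs_mask splits the site: observed entries stay in "x",
        # masked-out entries go to an auxiliary site named "x_unobserved".
        pyro.sample("x", dist.Normal(torch.zeros(len(data)), 1.0),
                    obs=data, obs_mask=mask)

trace = pyro.poutine.trace(tiny_model).get_trace(
    torch.zeros(3), torch.tensor([True, False, True]))
print("x_unobserved" in trace.nodes)  # True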

For instance, take the tutorial on Dirichlet Process Mixture Models. If you pass an obs_mask in the model, the code fails, even when the mask marks every observation as present (Pyro 1.8.5).

import torch
import pyro
from pyro.distributions import Beta, Categorical, MultivariateNormal

# T, alpha, N, and mix_weights are defined as in the tutorial.
def model(data):
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data,
                    # Specify a mask with no missing observations
                    obs_mask=torch.ones(N, dtype=torch.bool))

Adding the mask as shown above triggers the error message below.

(...)
    log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
KeyError: 'obs_unobserved'

The failure occurs in the function _compute_log_r in pyro/infer/trace_elbo.py: guide_trace.nodes does not contain a site named obs_unobserved to match the one in the model.
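
For context, the failing loop looks roughly like this (my simplified paraphrase of _compute_log_r, which may not match the exact source):

def compute_log_r_sketch(model_trace, guide_trace):
    # Rough paraphrase of _compute_log_r; not the exact Pyro code.
    log_r = {}
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                # This lookup raises KeyError when the model contains
                # "obs_unobserved" but the guide trace has no such node.
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            log_r[name] = log_r_term
    return log_r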

Since this happens on an essentially unmodified tutorial example, it looks like a bug, but I am not sure why the function _compute_log_r is not always called during inference. So, my questions are:

  1. Do we always have to manually create a site for the masked observations in the guide?
  2. Or is the proper way to compute log_r to just ignore the samples with the _unobserved tag?

Any insight as to when _compute_log_r is called would be the cherry on top.

Here is what I gathered by digging deeper into the code. The function _compute_log_r seems to be called only in the Tricky Case of the tutorial on ELBO gradient estimators, i.e. when at least one guide site cannot be reparameterized and the score-function (REINFORCE) term is needed. In the Easy Case, where every guide site is reparameterizable, it is not called at all.

In both the Easy Case and the Tricky Case, missing observations are put into a node with the _unobserved tag. In the Easy Case, those missing observations influence the ELBO only through the model (unless a pyro.sample with matching _unobserved tag is declared in the guide). The model log-likelihood of the missing observations is computed and added to the total; the guide log-likelihood does not contribute.
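
To illustrate the Easy Case, here is a toy example of mine (not the tutorial model): the guide is fully reparameterizable and declares no x_unobserved site, yet the step runs, with only the usual missing-site warning.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data, mask):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(loc * torch.ones(len(data)), 1.0),
                    obs=data, obs_mask=mask)

def guide(data, mask):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    # Reparameterizable guide; note the absence of an "x_unobserved" site.
    pyro.sample("loc", dist.Normal(q_loc, 1.0))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
svi.step(torch.zeros(3), torch.tensor([True, False, True]))  # warns, but runs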

In the Tricky Case, the function _compute_log_r is eventually called (once); it runs through the model nodes and expects to find every non-observed node in the guide as well. So in the Tricky Case the user has to declare a pyro.sample with a matching _unobserved tag in the guide, as in the sketch below.
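
Here is a matching toy Tricky Case (again my own example; the Normal variational family for the missing points is an arbitrary choice). The discrete site "z" forces the score-function estimator, so _compute_log_r runs, and the step only succeeds because the guide declares the matching site:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data, mask):
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(torch.ones(2) / 2))
        pyro.sample("x", dist.Normal(z.float(), 1.0), obs=data, obs_mask=mask)

def guide(data, mask):
    probs = pyro.param("probs", torch.ones(len(data), 2) / 2,
                       constraint=dist.constraints.simplex)
    x_loc = pyro.param("x_loc", torch.zeros(len(data)))
    with pyro.plate("data", len(data)):
        pyro.sample("z", dist.Categorical(probs))
        # Without this matching site, _compute_log_r raises
        # KeyError: 'x_unobserved'.
        pyro.sample("x_unobserved", dist.Normal(x_loc, 1.0))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
svi.step(torch.zeros(3), torch.tensor([True, False, True]))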

I think that the behavior should be the same in both cases, right? Either the guide contribution of the missing observations should be ignored by default in both cases, or the user should be required to sample them in the guide in both cases. I will open a pull request about this, but any insight would be welcome.

For those interested in the details, I opened an issue here and a pull request here. There seem to be some minor differences in the way Pyro treats missing observations with the reparametrization trick versus in the other cases, which can cause confusion and unexpected behavior.
