When some observations are masked, a new site with the `_unobserved` tag is added to the model. This usually only triggers a warning, but sometimes it also causes Pyro to crash for no obvious reason.
For instance, take the tutorial on Dirichlet Process Mixture Models. If you specify a mask in the model, the code fails, even if the mask specifies that no observations are missing (Pyro 1.8.5).

    import torch
    import pyro
    from pyro.distributions import Beta, Categorical, MultivariateNormal

    # T, N, alpha and mix_weights are defined as in the tutorial.
    def model(data):
        with pyro.plate("beta_plate", T-1):
            beta = pyro.sample("beta", Beta(1, alpha))
        with pyro.plate("mu_plate", T):
            mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))
        with pyro.plate("data", N):
            z = pyro.sample("z", Categorical(mix_weights(beta)))
            pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data,
                        # Specify a mask with no missing observations
                        obs_mask=torch.ones(N, dtype=torch.bool))

Adding the mask as shown above triggers the error message below.

    (...)
        log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
    KeyError: 'obs_unobserved'

The issue is in the function `_compute_log_r` from the file `pyro/infer/trace_elbo.py`, where `guide_trace.nodes` does not contain the matching site called `obs_unobserved`.
The example from the tutorial suggests that this is a bug, but I am not too sure why the function `_compute_log_r` is not always called during inference. So, my questions are:
- Do we always have to manually create a site for the masked observations in the guide?
- Or is the proper way to compute `log_r` to just ignore the samples with the `_unobserved` tag?
Any insight as to when `_compute_log_r` is called would be the cherry on the cake.