Hi. I could have sworn I’ve previously gotten poutine masking working to ignore nan observations, but I can’t do so in this super simple toy example. The parameters for the first 9 parent instances are all learned correctly, but the last is learned as nan. My intention was for the the single nan child observation to be masked out and for the last parent instance to be learned only using the other 19 non-nan child observations.
Am I missing something obvious? Did something change/break in a recent release?
MWE
import torch
import pyro
import pyro.distributions as dist
pyro.enable_validation(False)
num_parent = 10
num_child = 20
def model(obs):
with pyro.plate('parent_plate', size=num_parent, dim=-1):
parent = pyro.sample('parent', dist.Normal(0.5, 0.1))
with pyro.plate('child_plate', size=num_child, dim=-2):
with pyro.poutine.mask(mask=~torch.isnan(obs)):
pyro.sample(f'child', dist.Normal(parent, 0.01), obs=obs)
def guide(obs):
with pyro.plate('parent_plate', size=num_parent, dim=-1):
mean = pyro.param('parent_mean', torch.full((num_parent,), .5))
pyro.sample('parent', dist.Normal(mean, 0.1))
def main():
obs = torch.linspace(0, 1, num_parent).expand((num_child, num_parent))
obs[-1, -1] = float('nan')
svi = pyro.infer.SVI(
model,
guide,
pyro.optim.Adam({}),
loss=pyro.infer.Trace_ELBO(),
)
for _ in range(1000):
svi.step(obs)
param_store = pyro.get_param_store()
mean = param_store['parent_mean']
parent_pred = [round(float(i), 2) for i in mean]
print('parent prediction:', parent_pred)
if __name__ == '__main__':
main()