Hi. I could have sworn I’ve previously gotten poutine masking working to ignore nan
observations, but I can’t do so in this super simple toy example. The parameters for the first 9 parent instances are all learned correctly, but the last is learned as nan
. My intention was for the the single nan
child observation to be masked out and for the last parent instance to be learned only using the other 19 non-nan
child observations.
Am I missing something obvious? Did something change/break in a recent release?
MWE
import torch
import pyro
import pyro.distributions as dist
pyro.enable_validation(False)
num_parent = 10
num_child = 20
def model(obs):
with pyro.plate('parent_plate', size=num_parent, dim=-1):
parent = pyro.sample('parent', dist.Normal(0.5, 0.1))
with pyro.plate('child_plate', size=num_child, dim=-2):
with pyro.poutine.mask(mask=~torch.isnan(obs)):
pyro.sample(f'child', dist.Normal(parent, 0.01), obs=obs)
def guide(obs):
with pyro.plate('parent_plate', size=num_parent, dim=-1):
mean = pyro.param('parent_mean', torch.full((num_parent,), .5))
pyro.sample('parent', dist.Normal(mean, 0.1))
def main():
obs = torch.linspace(0, 1, num_parent).expand((num_child, num_parent))
obs[-1, -1] = float('nan')
svi = pyro.infer.SVI(
model,
guide,
pyro.optim.Adam({}),
loss=pyro.infer.Trace_ELBO(),
)
for _ in range(1000):
svi.step(obs)
param_store = pyro.get_param_store()
mean = param_store['parent_mean']
parent_pred = [round(float(i), 2) for i in mean]
print('parent prediction:', parent_pred)
if __name__ == '__main__':
main()