A follow-up warning for future visitors of this post. This solution requires a very small tweak, mentioned in Fritzo's previous posts (linked above in my first post), to get it to work when changing the parent distribution to Categorical: **before sampling the observed parent data in the model, you must replace the nan values.**

Nitty-gritty details: When calculating the loss, the log-probability is calculated for each observation, and then only the `poutine.mask`ed log-probabilities are used. This is fine for a Normal variable because the `torch.distributions.normal.log_prob` return statement can gracefully handle `nan`s. The `torch.distributions.categorical.log_prob` return statement, however, cannot gracefully handle `nan`s because it tries to use them as indices in `torch.gather`. The presence of `nan`s in the data will also raise an exception if Pyro validation is enabled.

Since the corresponding log-probability entry will be masked out anyway, the solution is to replace the `nan`s with any valid dummy category to make `gather` happy. We can confirm the choice of dummy value doesn't affect anything by setting the random seed and printing the loss when replacing `nan`s with different valid dummy values.

## MWE

```
import torch
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import numpy as np

plate_size = 10
prior = torch.tensor([.2, .8])

def basic_model(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)
        pyro.sample("child", dist.Normal(parent.float(), 1),
                    obs=child_data)

def basic_guide(child_data):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent = pyro.sample("parent", parent_dist)

def masked_model(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        with poutine.mask(mask=parent_mask):
            parent_1 = pyro.sample("parent_1", parent_dist, obs=parent_data).float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))
        pyro.sample("child", dist.Normal(parent, 1), obs=child_data)

def masked_guide(child_data, parent_data, parent_mask):
    with pyro.plate("plate", plate_size):
        parent_dist = dist.Categorical(prior).expand([plate_size])
        parent_1 = parent_data.float()
        with poutine.mask(mask=~parent_mask):
            parent_0 = pyro.sample("parent_0", parent_dist).float()  # no obs
        parent = pyro.deterministic("parent", torch.where(parent_mask, parent_1, parent_0))

child_data = torch.ones(plate_size, dtype=torch.float)
svi = SVI(basic_model, basic_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data))

parent_data = torch.ones(plate_size, dtype=torch.float)
parent_data[0] = np.nan
parent_mask = ~torch.isnan(parent_data)
# replace nans with a valid dummy value now that we have the mask
parent_data[~parent_mask] = 0
svi = SVI(masked_model, masked_guide, Adam({}), loss=Trace_ELBO())
print(svi.step(child_data, parent_data, parent_mask))
```
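To sanity-check the dummy-value claim without running full SVI, here is a minimal plain-PyTorch sketch (not part of the MWE above) showing that the masked log-probability sum is identical no matter which valid category fills the `nan` slot, since the filled entry is zeroed out exactly as `poutine.mask` does for the ELBO:

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.2, 0.8])
data = torch.tensor([1.0, 1.0, float("nan")])
mask = ~torch.isnan(data)

masked_sums = []
for dummy in (0.0, 1.0):  # both are valid categories
    # fill the nan slot with the dummy category
    filled = torch.where(mask, data, torch.full_like(data, dummy))
    lp = Categorical(probs).log_prob(filled)
    # zero out the masked entry, as poutine.mask does when computing the loss
    masked_sums.append((lp * mask).sum())

print(masked_sums)  # identical regardless of the dummy choice
```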