I could be wrong (this is common), but it seems as though pyro.infer.Predictive doesn't support dynamic model structure, i.e. models whose set of sample sites depends on an earlier random draw. For example:
import torch
import pyro
import pyro.distributions as dist

def model():
    choice = pyro.sample("choice", dist.Categorical(probs=0.5 * torch.ones(2)))
    if choice == 0:
        # normal model
        mu = pyro.sample("mu", dist.Normal(0, 1))
        sigma = pyro.sample("sigma", dist.Exponential(1))
        symm = pyro.sample("symm", dist.Normal(mu, sigma))
        return symm
    elif choice == 1:
        # exponential model
        rate = pyro.sample("rate", dist.Gamma(1, 3))
        asym = pyro.sample("asym", dist.Exponential(rate))
        return asym
Sampling from the prior predictive distribution:
prior_predictive = pyro.infer.Predictive(model, num_samples=100)
prior_samples = prior_predictive()
This raises a KeyError from inside Predictive:
---> 42 collected.append({site: trace.nodes[site]['value'] for site in return_site_shapes})
43
44 if return_trace:
KeyError: 'mu'
This makes sense: in at least one run of the model the categorical draw was 1, so the "mu" site was never created in that run. As a sanity check, drawing from the prior predictive distribution manually works:
draws = torch.stack([model() for _ in range(100)])  # 100 independent runs of the model
This gives the expected (strange-looking) prior predictive distribution: a 50/50 mixture of the two branches.
Questions:
a) Am I doing something wrong here? There is nothing in the documentation about dynamic vs. static model structure with Predictive.
b) This seems like a pretty easy fix: just put nan in the dictionary if a KeyError arises when trying to access a sample node that doesn't exist in a given run. Should I submit a pull request or something?
Thanks in advance and sorry if I’m missing something obvious!