BTW, what does CPD stand for?
Sorry, CPD stands for conditional probability distribution, i.e. p(a) and p(b|a).
Thanks for the Pyronic cleanup. I realized I was missing `config_enumerate` (either as a decorator on `model` or wrapping the model passed to `infer_discrete`). I usually use the latter, so when the GMM example above used the former, I missed it. Working MWE below. Thanks!
MWE
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete, config_enumerate
from pyro import poutine
import pyro.optim
# Turn on Pyro's validation checks so invalid usage is reported early.
pyro.enable_validation(True)
num_b = 100
a_cpd = torch.tensor(0.9)  # p(a = 1), passed as the Bernoulli probability for site "a"
b_cpd = torch.tensor([.01, .99])  # p(b = 1 | a), looked up with the value of a as index

# Two rows: row 0 has a = 0 and all of its b's = 0; row 1 has a = 1 and all of its b's = 1.
data = {
    'a': torch.tensor([[0.], [1.]]),
    'b': torch.stack([torch.zeros(num_b), torch.ones(num_b)]),
}
@config_enumerate
def model(a=None, b=None):
    """Two-level Bernoulli model: a parent site "a" with child sites "b".

    "a" sits in a plate of size 2 (dim -2) with p(a = 1) = ``a_cpd``.  Inside
    each "a" entry, ``num_b`` children "b" (plate dim -1) are drawn with
    p(b = 1 | a) = ``b_cpd[a]``.  ``@config_enumerate`` marks the discrete
    sites for enumeration so ``infer_discrete`` can fill in whichever site is
    left unobserved.

    Args:
        a: optional observed values for site "a" (``data["a"]`` in this
            script, shape (2, 1)); ``None`` leaves the site latent.
        b: optional observed values for site "b" (``data["b"]`` in this
            script, shape (2, num_b)); ``None`` leaves the site latent.
    """
    with pyro.plate("a_plate", size=2, dim=-2):
        a = pyro.sample("a", dist.Bernoulli(a_cpd), obs=a)
        # Use the (observed or enumerated) value of a as an index into p(b = 1 | a).
        b_probs = b_cpd[a.long()]
        with pyro.plate("b_plate", size=num_b, dim=-1):
            pyro.sample("b", dist.Bernoulli(b_probs), obs=b)
# temperature=0 asks infer_discrete for the most likely (MAP) assignment of the
# enumerated sites rather than a posterior sample. The two plates use dims -1
# and -2, so enumeration must start at dim -3.
inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)

# Observe one variable, then read back the value infer_discrete chooses for the other.
for target_var in ["a", "b"]:
    if target_var == "b":
        kwargs = {"a": data["a"].float()}
    else:
        kwargs = {"b": data["b"].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    site_value = trace.nodes[target_var]["value"]
    if target_var == "a":
        print('a pred:', [float(f'{v:.2f}') for v in site_value.flatten()])
    elif target_var == "b":
        # Row 0 of the observed a is 0 and row 1 is 1; report the mean predicted b per row.
        print('1st half b pred:', f'{site_value[0, :].mean():.2f}')
        print('2nd half b pred:', f'{site_value[1, :].mean():.2f}')