MAP Prediction

By the way, what does "cpd" stand for?

Sorry, it stands for "conditional probability distribution", i.e. p(a) and p(b|a).

Thanks for the "Pyronic" cleanup. I realized I was missing `config_enumerate`, which can be applied either as a decorator on `model` or by wrapping the model passed to `infer_discrete`. I usually use the latter, so when the GMM example above used the former, I missed it. A working MWE is below. Thanks!

MWE
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete, config_enumerate
from pyro import poutine
import pyro.optim

# Turn on Pyro's validation checks for this example so that modelling mistakes
# are reported rather than silently ignored.
pyro.enable_validation(True)

# Model size and conditional probability tables (CPDs):
#   P(a = 1)      = a_cpd
#   P(b = 1 | a)  = b_cpd[a]   (index 0 when a = 0, index 1 when a = 1)
num_b = 100
a_cpd = torch.tensor(0.9)
b_cpd = torch.tensor([.01, .99])

# Two rows of observations: row 0 has a = 0 with all of its b's = 0,
# row 1 has a = 1 with all of its b's = 1.
data = {
    'a': torch.tensor([[0.], [1.]]),
    'b': torch.stack([torch.zeros(num_b), torch.ones(num_b)]),
}

@config_enumerate
def model(a=None, b=None):
    """Two-level Bernoulli network p(a) * p(b | a) laid out on a 2 x ``num_b`` grid.

    One ``a`` is sampled per row (plate ``"a_plate"``, size 2, dim -2) from
    ``Bernoulli(a_cpd)``. Every ``b`` in that row (plate ``"b_plate"``, size
    ``num_b``, dim -1) is then sampled from ``Bernoulli(b_cpd[a])``, so the
    row's ``a`` value picks which entry of ``b_cpd`` is used.

    Args:
        a: Observed values for site ``"a"`` (e.g. ``data['a']``), or None to
            leave ``a`` latent so it can be enumerated / inferred.
        b: Observed values for site ``"b"`` (e.g. ``data['b']``), or None to
            leave ``b`` latent.
    """
    with pyro.plate("a_plate", size=2, dim=-2):
        a_val = pyro.sample("a", dist.Bernoulli(a_cpd), obs=a)
        # Bernoulli values are float 0./1.; cast to integers so they can index
        # the CPD table. Under enumeration a_val may carry extra dims on the
        # left (infer_discrete is called with first_available_dim=-3).
        b_probs = b_cpd[a_val.long()]
        with pyro.plate("b_plate", size=num_b, dim=-1):
            pyro.sample("b", dist.Bernoulli(b_probs), obs=b)


# temperature=0 makes infer_discrete return MAP values for the latent discrete
# sites; dims -1 and -2 are taken by the plates, so enumeration starts at -3.
inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)

# Predict each variable from the other one: "a" is decoded given the observed b,
# and "b" is decoded given the observed a.
conditioning_var = {"a": "b", "b": "a"}

for target_var in ["a", "b"]:
    given = conditioning_var[target_var]
    kwargs = {given: data[given].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    prediction = trace.nodes[target_var]["value"]
    if target_var == "a":
        print('a pred:', [round(v, 2) for v in prediction.flatten().tolist()])
    else:
        # Each row of b shares one a value, so report the mean prediction per row.
        print('1st half b pred:', f'{prediction[0, :].mean():.2f}')
        print('2nd half b pred:', f'{prediction[1, :].mean():.2f}')
1 Like