MAP Prediction

I’m failing to get the MAP prediction via infer_discrete fully working for a toy model. It’s (a) --> (b), where a is a Bernoulli and b is a mixture of Bernoullis given the value of a. I’m hardcoding the CPDs, so no learning, just prediction for now.

Predicting b given a works great, but predicting a given b just seems to return draws from a's hardcoded CPD, as if no information from b is getting “passed back up”. Am I conceptually missing something here?

MWE
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete
from pyro.infer.autoguide import AutoDelta
from pyro import poutine
import pyro.optim

pyro.enable_validation(True)

num_b = 10

# row 0 of a and b is all zeros, row 1 is all ones
data = {
    'a': torch.tensor([[0.], [1.]]),
    'b': torch.cat([torch.zeros(1, num_b), torch.ones(1, num_b)]),
}


def sample(target_var):
    """Toy two-node model a -> b with hard-coded CPDs.

    The site named by ``target_var`` ('a' or 'b') is left unobserved so it can
    be predicted; the other site is observed from the module-level ``data``.

    Returns the tuple ``(a, b)`` of values produced at the two sample sites.
    """
    b_plate = pyro.plate('b_plate', size=num_b, dim=-1)
    a_plate = pyro.plate('a_plate', size=2, dim=-2)

    # hard-coded CPDs: p(a = 1), and p(b = 1 | a) indexed by the value of a
    a_cpd = torch.tensor(.9)
    b_cpd = torch.tensor([.01, .99])

    # the prediction target gets obs=None, the other site is clamped to data
    if target_var == 'a':
        a_obs = None
    else:
        a_obs = data['a'].float()
    if target_var == 'b':
        b_obs = None
    else:
        b_obs = data['b'].float()

    with a_plate:
        a = pyro.sample('a', dist.Bernoulli(a_cpd), obs=a_obs)

    with a_plate, b_plate:
        # a.long() turns the value of a into an index into b_cpd
        b = pyro.sample('b', dist.Bernoulli(b_cpd[a.long()]), obs=b_obs)

    return a, b


def init_loc_fn(site):
    """Return an all-zeros initial value for the given site.

    Site 'a' gets shape (2, 1); any other site gets shape (1, num_b).
    """
    if site['name'] == 'a':
        shape = (2, 1)
    else:
        shape = (1, num_b)
    return torch.full(shape, 0)


# Predict a given b, then b given a.
# The guide is an AutoDelta built over the model wrapped in poutine.block;
# init_loc_fn supplies the initial values it uses.
global_guide = AutoDelta(poutine.block(sample), init_loc_fn=init_loc_fn)
for target_var in ['a', 'b']:
    guide_trace = poutine.trace(global_guide).get_trace(target_var)  # record the guide's sites
    trained_model = poutine.replay(sample, trace=guide_trace)  # replay them into the model
    # dims -1 and -2 are taken by b_plate and a_plate, so -3 is the first free dim
    inferred_model = infer_discrete(trained_model, temperature=0, first_available_dim=-3)
    trace = poutine.trace(inferred_model).get_trace(target_var)

    # print the value recorded at the target site, rounded to two decimals
    print(f'{target_var} prediction:', [float(f'{i:.2f}') for i in trace.nodes[target_var]["value"].flatten()])

Hi @gbernstein,
I’m not sure what’s wrong here. BTW what does cpd stand for?

Note a slightly more pyronic way to write your model (sorry if this doesn’t work, and thanks for creating a minimal working example) might be to use defaults for your model arguments:

def model(a=None, b=None):
    """Toy model a -> b with hard-coded CPDs.

    Pass a tensor for ``a`` or ``b`` to observe that site; leave it as ``None``
    to have it sampled.

    Returns the tuple ``(a, b)`` of values produced at the two sample sites.
    """
    a_cpd = torch.tensor(0.9)
    # p(b = 1 | a), indexed by the value of a; must be defined here because
    # it is not a module-level name in the original example
    b_cpd = torch.tensor([0.01, 0.99])
    with pyro.plate("a_plate", size=2, dim=-2):
        a = pyro.sample("a", dist.Bernoulli(a_cpd),
                        obs=a)
        with pyro.plate("b_plate", size=num_b, dim=-1):
            b = pyro.sample("b", dist.Bernoulli(b_cpd[a.long()]),
                            obs=b)
    return a, b

Since there are no latent variables (again thanks for simplifying) we can simply define

# Wrap the keyword-argument model directly; no guide or replay is needed
# because there are no latent parameters to fit.
# NOTE: infer_discrete only enumerates sites that are configured for
# enumeration (e.g. via config_enumerate); see the follow-up in this thread.
inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)
for target_var in ["a", "b"]:
    # observe the other variable and leave the target unobserved
    kwargs = {"a": data["a"].float()} if target_var == "b" else {"b": data["b"].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    print(f'{target_var} prediction:', [float(f'{i:.2f}') for i in trace.nodes[target_var]["value"].flatten()])

This should indeed produce conditional samples according to p(a|b) and p(b|a).

1 Like

BTW what does cpd stand for?

Sorry, conditional probability distribution, so the p(a) and p(b|a).

Thanks for the pyronic cleanup. I realized I was missing the config_enumerate (either as a decorator for model or wrapping the input to infer_discrete). I usually use the latter, so when the GMM example above used the former I missed it. Working MWE below. Thanks!

MWE
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete, config_enumerate
from pyro import poutine
import pyro.optim

pyro.enable_validation(True)

num_b = 100

# hard-coded CPDs: p(a = 1), and p(b = 1 | a) indexed by the value of a
a_cpd = torch.tensor(0.9)
b_cpd = torch.tensor([.01, .99])

# row 0 of a and b is all zeros, row 1 is all ones
data = {
    'a': torch.tensor([[0.], [1.]]),
    'b': torch.cat([torch.zeros(1, num_b), torch.ones(1, num_b)]),
}

@config_enumerate
def model(a=None, b=None):
    """Toy model a -> b using the module-level CPDs ``a_cpd`` and ``b_cpd``.

    Pass a tensor for ``a`` or ``b`` to observe that site; leave it as ``None``
    to have it sampled.
    """
    with pyro.plate("a_plate", size=2, dim=-2):
        prior_a = dist.Bernoulli(a_cpd)
        a = pyro.sample("a", prior_a, obs=a)
        with pyro.plate("b_plate", size=num_b, dim=-1):
            # a.long() turns the value of a into an index into b_cpd
            b_given_a = dist.Bernoulli(b_cpd[a.long()])
            pyro.sample("b", b_given_a, obs=b)


# dims -1 and -2 are taken by b_plate and a_plate, so -3 is the first free dim
inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)

for target_var in ["a", "b"]:
    # observe the other variable and leave the target unobserved
    if target_var == "b":
        kwargs = {"a": data["a"].float()}
    else:
        kwargs = {"b": data["b"].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    pred = trace.nodes[target_var]["value"]
    if target_var == "a":
        print('a pred:', [float(f'{i:.2f}') for i in pred.flatten()])
    elif target_var == "b":
        # one summary number per row of b: the mean prediction across the row
        print('1st half b pred:', f'{pred[0, :].mean():.2f}')
        print('2nd half b pred:', f'{pred[1, :].mean():.2f}')
1 Like

@gbernstein good debugging! We’ll add a warning in that case, similar to the warning in TraceEnum_ELBO. Hopefully future users won’t need to debug :smile:

1 Like