 # MAP Prediction

I can’t get MAP prediction via `infer_discrete` fully working for a toy model. The model is `(a) --> (b)`, where `a` is a Bernoulli and `b` is a mixture of Bernoullis selected by the value of `a`. I’m hardcoding the CPDs, so there is no learning, just prediction for now.

Predicting `b` given `a` works great, but predicting `a` given `b` just seems to return draws from `a`'s hardcoded CPD, as if no information from `b` is getting “passed back up”. Am I conceptually missing something here?

MWE
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete
from pyro.infer.autoguide import AutoDelta
from pyro import poutine
import pyro.optim

pyro.enable_validation(True)

num_b = 10
data = {}

# first half of a and b are 0, second half are 1
data['a'] = torch.zeros([2, 1])
data['b'] = torch.zeros([2, num_b])
data['a'][1, 0] = 1
data['b'][1, :] = 1

def sample(target_var):
    b_plate = pyro.plate('b_plate', size=num_b, dim=-1)
    a_plate = pyro.plate('a_plate', size=2, dim=-2)

    # sample a
    a_cpd = torch.tensor(.9)
    with a_plate:
        # unobserved as prediction target var, otherwise observed
        obs = None if target_var == 'a' else data['a'].float()
        a = pyro.sample('a',
                        dist.Bernoulli(a_cpd),
                        obs=obs)

    # sample b
    b_cpd = torch.tensor([.01, .99])
    with a_plate, b_plate:
        # unobserved as prediction target var, otherwise observed
        obs = None if target_var == 'b' else data['b'].float()
        b = pyro.sample('b',
                        dist.Bernoulli(b_cpd[a.long()]),
                        obs=obs)

    return a, b

def init_loc_fn(site):
    size = (2, 1) if site['name'] == 'a' else (1, num_b)

# predict a given b and b given a
global_guide = AutoDelta(poutine.block(sample), init_loc_fn=init_loc_fn)
for target_var in ['a', 'b']:
    guide_trace = poutine.trace(global_guide).get_trace(target_var)  # record the globals
    trained_model = poutine.replay(sample, trace=guide_trace)  # replay the globals
    inferred_model = infer_discrete(trained_model, temperature=0, first_available_dim=-3)
    trace = poutine.trace(inferred_model).get_trace(target_var)

    print(f'{target_var} prediction:', [float(f'{i:.2f}') for i in trace.nodes[target_var]["value"].flatten()])
```

Hi @gbernstein,
I’m not sure what’s wrong here. BTW what does cpd stand for?

Note a slightly more pyronic way to write your model (sorry if this doesn’t work, and thanks for creating a minimal working example) might be to use defaults for your model arguments:

```python
def model(a=None, b=None):
    a_cpd = torch.tensor(0.9)
    b_cpd = torch.tensor([0.01, 0.99])
    with pyro.plate("a_plate", size=2, dim=-2):
        a = pyro.sample("a", dist.Bernoulli(a_cpd),
                        obs=a)
        with pyro.plate("b_plate", size=num_b, dim=-1):
            b = pyro.sample("b", dist.Bernoulli(b_cpd[a.long()]),
                            obs=b)
    return a, b
```

Since there are no latent variables (again thanks for simplifying) we can simply define

```python
inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)
for target_var in ["a", "b"]:
    # observe the other variable and leave the target unobserved
    kwargs = {"a": data["a"].float()} if target_var == "b" else {"b": data["b"].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    print(f'{target_var} prediction:', [float(f'{i:.2f}') for i in trace.nodes[target_var]["value"].flatten()])
```

This should indeed produce conditional predictions according to `p(a|b)` and `p(b|a)`: MAP values with `temperature=0`, or posterior samples with `temperature=1`.
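As a sanity check on what `p(a|b)` ought to look like for this toy model, here is a rough sketch that applies Bayes' rule by hand in plain Python. It assumes the hardcoded CPDs from the original post and `num_b = 10`; the helper name `posterior_a1` is made up for illustration and is not part of Pyro.

```python
import math

a_prior = 0.9          # p(a = 1), the hardcoded a_cpd
b_cpd = [0.01, 0.99]   # p(b = 1 | a = 0) and p(b = 1 | a = 1)

def posterior_a1(b_row):
    """Exact p(a = 1 | b_row), treating the b's as i.i.d. Bernoullis given a."""
    log_joint = []
    for a, p_a in enumerate([1 - a_prior, a_prior]):
        # log-likelihood of the observed row under this value of a
        log_lik = sum(math.log(b_cpd[a] if b else 1 - b_cpd[a]) for b in b_row)
        log_joint.append(math.log(p_a) + log_lik)
    # normalize in a numerically stable way
    m = max(log_joint)
    weights = [math.exp(v - m) for v in log_joint]
    return weights[1] / sum(weights)

num_b = 10
p_zeros = posterior_a1([0] * num_b)  # row where every b is 0
p_ones = posterior_a1([1] * num_b)   # row where every b is 1
print(f"p(a=1 | all b=0) = {p_zeros:.3g}")
print(f"p(a=1 | all b=1) = {p_ones:.3g}")
```

Even with a prior of 0.9 on `a = 1`, ten observed zeros should pull the posterior for the first row very close to 0, so a MAP prediction that just follows the prior is a sign that the evidence from `b` is not being used.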


> BTW what does cpd stand for?

Sorry, conditional probability distribution, so the `p(a)` and `p(b|a)`.

Thanks for the pyronic cleanup. I realized I was missing the `config_enumerate` (either as a decorator for `model` or wrapping the input to `infer_discrete`). I usually use the latter, so when the GMM example above used the former I missed it. Working MWE below. Thanks!

MWE
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete, config_enumerate
from pyro import poutine
import pyro.optim

pyro.enable_validation(True)

num_b = 100
a_cpd = torch.tensor(0.9)
b_cpd = torch.tensor([.01, .99])

# first half of a and b are 0, second half are 1
data = {}
data['a'] = torch.zeros([2, 1])
data['b'] = torch.zeros([2, num_b])
data['a'][1, 0] = 1
data['b'][1, :] = 1

@config_enumerate
def model(a=None, b=None):

    with pyro.plate("a_plate", size=2, dim=-2):
        a = pyro.sample("a", dist.Bernoulli(a_cpd), obs=a)
        with pyro.plate("b_plate", size=num_b, dim=-1):
            pyro.sample("b", dist.Bernoulli(b_cpd[a.long()]), obs=b)

inferred_model = infer_discrete(model, temperature=0, first_available_dim=-3)

for target_var in ["a", "b"]:
    kwargs = {"a": data["a"].float()} if target_var == "b" else {"b": data["b"].float()}
    trace = poutine.trace(inferred_model).get_trace(**kwargs)
    if target_var == "a":
        print('a pred:', [float(f'{i:.2f}') for i in trace.nodes[target_var]["value"].flatten()])
    elif target_var == "b":
        print('1st half b pred:', f'{trace.nodes[target_var]["value"][0, :].mean():.2f}')
        print('2nd half b pred:', f'{trace.nodes[target_var]["value"][1, :].mean():.2f}')
```

@gbernstein good debugging! We’ll add a warning in that case, similar to the warning in `TraceEnum_ELBO`. Hopefully future users won’t need to debug.