[Beginners] Toy examples/tutorials for exact inference on discrete variables?

Hi @dreamerlzl, I'm not sure whether this answer helps, since it relies on some internal utilities to compute the conditional probability. The benefit is that the approach carries over to more general models.

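```python
import torch
import pyro
import pyro.distributions as dist

# Mark the discrete sample sites of the model for enumeration.
@pyro.infer.config_enumerate
def model(z=None, p1=0.3, p2=0.8, p3=0.4):
    x1 = pyro.sample('x1', dist.Bernoulli(p1))
    x2 = pyro.sample('x2', dist.Bernoulli(p2))
    # z can only be 1 when both x1 and x2 are 1; it is then 1 with probability p3.
    pyro.sample('z', dist.Bernoulli((x1.bool() & x2.bool()) * p3), obs=z)

# Fix x1 = 1, then wrap in poutine.enum so the marked sites are actually enumerated.
# The model has no plates, so max_plate_nesting is 0 and enumeration starts at dim -1.
p_z_x1 = pyro.do(model, data={'x1': torch.tensor(1.)})
p_z_x1_enum = pyro.poutine.enum(p_z_x1, first_available_dim=-1)

# Record a trace with z observed at 0; the evaluator sums out the enumerated dims exactly.
trace = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(0.))
log_prob_evaluate = pyro.infer.mcmc.util.TraceEinsumEvaluator(
    trace, has_enumerable_sites=True, max_plate_nesting=0)
print("p(z=0|x1=1):", log_prob_evaluate.log_prob(trace).exp())

# Reuse the same evaluator on a trace with z observed at 1.
trace1 = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(1.))
print("p(z=1|x1=1):", log_prob_evaluate.log_prob(trace1).exp())
```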

which outputs

```
p(z=0|x1=1): tensor(0.6800)
p(z=1|x1=1): tensor(0.3200)
```
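To see where these numbers come from, the same enumeration can be done by hand. Below is a minimal plain-Python sketch (no Pyro or torch; the helper names `bernoulli_pmf` and `p_z_given_x1` are hypothetical) that fixes x1 = 1 and sums over both values of x2, weighting each by its prior probability. Because x1 has no parents in this model, the result is the conditional p(z|x1=1).

```python
P1, P2, P3 = 0.3, 0.8, 0.4  # same defaults as the Pyro model


def bernoulli_pmf(value, p):
    """Probability of `value` (0 or 1) under Bernoulli(p)."""
    return p if value == 1 else 1.0 - p


def p_z_given_x1(z, x1=1):
    """p(z | x1), computed by explicitly summing out x2."""
    total = 0.0
    for x2 in (0, 1):  # enumerate every value of the latent x2
        # z can only be 1 when both parents are 1; its success probability is then P3
        p_z = P3 if (x1 == 1 and x2 == 1) else 0.0
        total += bernoulli_pmf(x2, P2) * bernoulli_pmf(z, p_z)
    return total


for z in (0, 1):
    print(f"p(z={z}|x1=1): {p_z_given_x1(z):.4f}")
```

It prints 0.6800 and 0.3200, matching the Pyro output above: for z = 1 only the x2 = 1 branch contributes, giving 0.8 × 0.4 = 0.32.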

Here are some notes that might make the implementation clearer:

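  • I used pyro.do to compute the conditional probability p(z|x1=1). Strictly speaking this is the interventional p(z|do(x1=1)), but since x1 has no parents in this model the two coincide. If you want the joint probability p(z, x1=1) instead, replace pyro.do with pyro.condition, which also scores the observation x1=1 and so multiplies the result by p(x1=1) = 0.3.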
  • config_enumerate marks the discrete sites for enumeration, so that each one is evaluated at all of its possible values instead of at a single drawn sample. Alternatively, you can pass infer={'enumerate': 'parallel'} to each pyro.sample call individually, which is more flexible because you choose exactly which sites get enumerated. Wrapping the model in poutine.enum is what actually performs the enumeration; without it, those infer annotations are ignored.
  • TraceEinsumEvaluator computes the log_prob of a model trace, summing out the enumerated sites (here x2) with einsum. Alternatively, I think you can use TraceEnum_ELBO as in this answer.
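The difference between pyro.do and pyro.condition in the first note can also be checked by hand. This sketch (plain Python with hypothetical helper names; it is not Pyro's implementation) scores x1 = 1 as an observation, as conditioning does, and sums out x2 in log space with a log-sum-exp, the same kind of marginalization an einsum-based evaluator performs over an enumerated dimension. Dividing the joint p(z, x1=1) by p(x1=1) = 0.3 recovers the conditional.

```python
import math

P1, P2, P3 = 0.3, 0.8, 0.4  # same defaults as the Pyro model


def log_bernoulli(value, p):
    """log-probability of value in {0, 1} under Bernoulli(p); -inf if impossible."""
    prob = p if value == 1 else 1.0 - p
    return math.log(prob) if prob > 0 else -math.inf


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_joint_z_x1(z, x1=1):
    """log p(z, x1) with x2 summed out; the x1 term is what conditioning adds."""
    terms = []
    for x2 in (0, 1):  # the enumerated dimension
        p_z = P3 if (x1 == 1 and x2 == 1) else 0.0
        terms.append(log_bernoulli(x1, P1) + log_bernoulli(x2, P2) + log_bernoulli(z, p_z))
    return logsumexp(terms)


for z in (0, 1):
    joint = math.exp(log_joint_z_x1(z))
    conditional = joint / P1  # divide out p(x1=1) to get p(z | x1=1)
    print(f"p(z={z}, x1=1) = {joint:.4f}, p(z={z}|x1=1) = {conditional:.4f}")
```

For z = 0 this gives p(z=0, x1=1) = 0.3 × 0.68 = 0.204, and for z = 1 it gives 0.3 × 0.32 = 0.096. The impossible branch (z = 1 with x2 = 0) contributes a log-probability of -inf, which the log-sum-exp treats as zero probability.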

Please let me know if anything is unclear.
