Hi @dreamerlzl, I am not sure if my answer is helpful, because you have to use some utilities to compute such a conditional probability. The benefit is that it applies to more general models.
```python
import torch
import pyro
import pyro.distributions as dist

@pyro.infer.config_enumerate
def model(z=None, p1=0.3, p2=0.8, p3=0.4):
    x1 = pyro.sample('x1', dist.Bernoulli(p1))
    x2 = pyro.sample('x2', dist.Bernoulli(p2))
    # z fires with probability p3 only when both x1 and x2 are 1
    pyro.sample('z', dist.Bernoulli((x1.bool() & x2.bool()) * p3), obs=z)

# intervene on x1, so the trace's log_prob excludes the p(x1) factor
p_z_x1 = pyro.do(model, data={'x1': torch.tensor(1.)})
# enumerate the remaining discrete sites instead of sampling them
p_z_x1_enum = pyro.poutine.enum(p_z_x1, first_available_dim=-1)

trace = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(0.))
log_prob_evaluate = pyro.infer.mcmc.util.TraceEinsumEvaluator(
    trace, has_enumerable_sites=True, max_plate_nesting=1)
print("p(z=0|x1=1):", log_prob_evaluate.log_prob(trace).exp())

trace1 = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(1.))
print("p(z=1|x1=1):", log_prob_evaluate.log_prob(trace1).exp())
```
which outputs
```
p(z=0|x1=1): tensor(0.6800)
p(z=1|x1=1): tensor(0.3200)
```
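As a sanity check, these values match enumerating `x2` by hand. This is plain Python, independent of Pyro; `p2` and `p3` are the model's defaults above:

```python
# Hand verification of p(z | do(x1=1)) by summing out the discrete x2.
# Under do(x1=1), z=1 requires x2=1 as well, and then fires with prob p3.
p2, p3 = 0.8, 0.4

# Sum over x2 in {0, 1}: p(x2) * p(z=1 | x1=1, x2)
p_z1 = (1 - p2) * 0.0 + p2 * p3   # x2=0 contributes 0; x2=1 contributes p2 * p3
p_z0 = 1 - p_z1

print(f"p(z=0|x1=1): {p_z0:.4f}")  # 0.6800
print(f"p(z=1|x1=1): {p_z1:.4f}")  # 0.3200
```

This is exactly the sum that the enumeration machinery (`poutine.enum` + `TraceEinsumEvaluator`) performs for you, just written out for this two-variable case.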
Here are some notes that might make the implementation clearer:

- I used `pyro.do` to compute the conditional probability p(z|x1=1). If you want to compute the joint probability p(z, x1=1), you just replace it with `pyro.condition`.
- `config_enumerate` is used to say that we want to enumerate the discrete sites, instead of drawing a single sample from those sites. Alternatively, you can add the keyword `infer={'enumerate': 'parallel'}` to each of those sites (this is more flexible because you can control which sites you want to enumerate). We wrap the model with `poutine.enum` to actually do the enumeration (otherwise, those `infer` keywords will be ignored).
- `TraceEinsumEvaluator` is used to compute the joint log_prob of a model trace. Alternatively, I think you can use `TraceEnum_ELBO` as in this answer.
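To make the `pyro.do` vs. `pyro.condition` distinction in the first note concrete, here is the same arithmetic by hand (plain Python; `p1`, `p2`, `p3` are the model's defaults, and the 0.096 figure is my own hand computation, not output from the code above):

```python
# pyro.do(x1=1) drops x1's own probability factor (an intervention),
# while pyro.condition(x1=1) keeps the prior factor p(x1=1) = p1,
# so the trace's log_prob becomes the joint p(z, x1=1).
p1, p2, p3 = 0.3, 0.8, 0.4

p_z1_do = p2 * p3          # p(z=1 | do(x1=1)) = 0.32, as printed above
p_z1_joint = p1 * p2 * p3  # p(z=1, x1=1) under pyro.condition

print(f"{p_z1_do:.3f} {p_z1_joint:.3f}")  # 0.320 0.096
```

Since `x1` is a root node here, the conditional and interventional distributions of `z` coincide; the two handlers differ only in whether the p(x1=1) factor enters the trace's log_prob.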
Please let me know if there is anything not clear to you.