Hi @dreamerlzl, I'm not sure how helpful this answer is, since it relies on a few Pyro utilities to compute the conditional probability. The benefit is that the same approach applies to more general models.

    import torch
    import pyro
    import pyro.distributions as dist


    @pyro.infer.config_enumerate
    def model(z=None, p1=0.3, p2=0.8, p3=0.4):
        x1 = pyro.sample('x1', dist.Bernoulli(p1))
        x2 = pyro.sample('x2', dist.Bernoulli(p2))
        # z can only be 1 (with probability p3) when both x1 and x2 are 1
        pyro.sample('z', dist.Bernoulli((x1.bool() & x2.bool()) * p3), obs=z)


    # Fix x1 = 1, then enumerate the remaining discrete latent site (x2)
    p_z_x1 = pyro.do(model, data={'x1': torch.tensor(1.)})
    p_z_x1_enum = pyro.poutine.enum(p_z_x1, first_available_dim=-1)

    trace = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(0.))
    # The model has no plates, so max_plate_nesting=0 matches first_available_dim=-1
    log_prob_evaluate = pyro.infer.mcmc.util.TraceEinsumEvaluator(
        trace, has_enumerable_sites=True, max_plate_nesting=0)
    print("p(z=0|x1=1):", log_prob_evaluate.log_prob(trace).exp())

    trace1 = pyro.poutine.trace(p_z_x1_enum).get_trace(z=torch.tensor(1.))
    print("p(z=1|x1=1):", log_prob_evaluate.log_prob(trace1).exp())

which outputs

    p(z=0|x1=1): tensor(0.6800)
    p(z=1|x1=1): tensor(0.3200)

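
As a sanity check that does not depend on Pyro, here is a minimal sketch that sums over `x2` by hand for the same model. The helper name `p_z_given_do_x1` is my own and is not part of any library; it simply hard-codes the structure of the model above, so treat it as an illustration rather than a general tool.

```python
def p_z_given_do_x1(z, x1=1, p2=0.8, p3=0.4):
    """Brute-force p(z | do(x1)) for the model above by summing over x2.

    Hypothetical helper for illustration only; not a Pyro API.
    """
    total = 0.0
    for x2 in (0, 1):
        # Prior probability of this value of x2
        p_x2 = p2 if x2 == 1 else 1.0 - p2
        # z ~ Bernoulli((x1 AND x2) * p3)
        p_z1 = (x1 & x2) * p3
        p_z = p_z1 if z == 1 else 1.0 - p_z1
        total += p_x2 * p_z
    return total


print("p(z=0|x1=1):", p_z_given_do_x1(0))
print("p(z=1|x1=1):", p_z_given_do_x1(1))
```

The two values should agree with the Pyro output up to floating-point rounding (roughly 0.68 and 0.32). Because `x1` has no parents in this model, intervening on it and conditioning on it give the same distribution over `z`, which is why the intervened model yields the conditional probability here.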
Here are some notes that might make the implementation clearer:
- I used `pyro.do` to compute the conditional probability p(z|x1=1). If you want the joint probability p(z, x1=1) instead, just replace it with `pyro.condition`.
- `config_enumerate` says that we want to enumerate the discrete sites instead of drawing a single sample from each of them. Alternatively, you can add the keyword `infer={'enumerate': 'parallel'}` to each of those sites (this is more flexible because you can control which sites get enumerated). We wrap the model with `poutine.enum` to actually perform the enumeration (otherwise, those `infer` keywords are ignored).
- `TraceEinsumEvaluator` is used to compute the joint log_prob of a model trace. Alternatively, I think you can use `TraceEnum_ELBO`, as in this answer.
Please let me know if there is anything not clear to you.