Obtain conditional probabilities for discrete variables

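You can use `TraceEnum_ELBO.compute_marginals()` with a model conditioned on your observations. It returns the marginal distribution of each enumerated latent site given the observed values. You can then read off conditional probabilities such as P(x1 = 1 | x2 = 1):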

@config_enumerate
def model(x1=None, x2=None):
    """Discrete model with a binary latent ``z`` and two emissions.

    ``z`` is drawn from a 2-way Categorical. ``x1`` (3 categories) and
    ``x2`` (4 categories) are then each drawn from a Categorical whose
    logits are selected by ``z``. Passing a value for ``x1`` / ``x2``
    conditions that site on the observation. Leaving it as ``None`` keeps
    the site latent, so ``@config_enumerate`` lets ``TraceEnum_ELBO``
    enumerate it.

    Args:
        x1: optional observed value for site ``"x1"`` (a category index 0..2).
        x2: optional observed value for site ``"x2"`` (a category index 0..3).
    """
    # Unconstrained logits, registered with pyro.param under these names and
    # randomly initialised via torch.randn.
    z_logits = pyro.param("p_z", torch.randn(2))
    x1_logits = pyro.param("p_x1", torch.randn(2, 3))  # row k: logits of x1 | z=k
    x2_logits = pyro.param("p_x2", torch.randn(2, 4))  # row k: logits of x2 | z=k
    z = pyro.sample("z", dist.Categorical(logits=z_logits))
    # Indexing by z picks the emission row for the sampled (or enumerated)
    # value of z. The sample results are not needed, so they are not bound
    # back onto the x1/x2 arguments.
    pyro.sample("x1", dist.Categorical(logits=x1_logits[z]), obs=x1)
    pyro.sample("x2", dist.Categorical(logits=x2_logits[z]), obs=x2)
    

def guide(**kwargs):
    """Empty guide used together with ``TraceEnum_ELBO``.

    The model's discrete latent sites are marked for enumeration with
    ``@config_enumerate``, so the guide has nothing to sample. It only needs
    to accept whatever keyword arguments the model is called with.
    """

# The enumeration ELBO sums the model's enumerated discrete sites out exactly
# (the guide is empty, so nothing is sampled).
elbo = TraceEnum_ELBO()
# Condition on x2 == 1. Categorical values are category indices, so pass an
# integer tensor (matching the integer query below) rather than a float.
conditional_marginals = elbo.compute_marginals(model, guide, x2=torch.tensor(1))
# Marginal distribution of the still-latent x1 given x2 == 1; evaluate
# P(x1 = 1 | x2 = 1).
p_x1_1 = conditional_marginals["x1"].log_prob(torch.tensor(1)).exp()
print(p_x1_1)
# Example output (the exact value depends on the random parameter init):
# tensor(0.3103, grad_fn=<ExpBackward>)
2 Likes