Obtain conditional probabilities for discrete variables

Say I have a model where a latent variable Z influences two discrete observed variables, X1 and X2. My Pyro program has learned the parameters governing the influence of Z on both X variables using an SVI approach, so I have a model and a guide function.

Is there an easy, straightforward way of obtaining the conditional probability of one X variable given the other, for instance p( X1 = 1 | X2 = 1 )?


You could use TraceEnum_ELBO.compute_marginals() with a model conditioned on your observations.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate

@config_enumerate
def model(x1=None, x2=None):
    # In practice these params hold the values learned by SVI;
    # random initializations are used here only for illustration.
    z_logits = pyro.param("p_z", torch.randn(2))
    x1_logits = pyro.param("p_x1", torch.randn(2, 3))
    x2_logits = pyro.param("p_x2", torch.randn(2, 4))
    z = pyro.sample("z", dist.Categorical(logits=z_logits))
    x1 = pyro.sample("x1", dist.Categorical(logits=x1_logits[z]), obs=x1)
    x2 = pyro.sample("x2", dist.Categorical(logits=x2_logits[z]), obs=x2)

def guide(**kwargs):
    # Empty guide: z is enumerated in the model and x1, x2 are observed.
    pass

elbo = TraceEnum_ELBO()
# Condition on X2 = 1 and compute the marginals of the remaining sites.
conditional_marginals = elbo.compute_marginals(model, guide, x2=torch.tensor(1))
# p(X1 = 1 | X2 = 1)
p_x1_1 = conditional_marginals["x1"].log_prob(torch.tensor(1)).exp()
print(p_x1_1)
# tensor(0.3103, grad_fn=<ExpBackward>)
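
If you want the whole conditional distribution p(X1 | X2 = 1) rather than a single entry, you can evaluate the returned marginal at every category. A minimal sketch, assuming X1 has the 3 categories implied by the shape of p_x1 in the model above:

# Evaluate the marginal of x1 at each of its categories.
p_x1 = conditional_marginals["x1"].log_prob(torch.arange(3)).exp()
print(p_x1)  # length-3 tensor that sums to 1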

@fritzo thanks for your reply. Quick follow-up question:

Would this also work for (or exist in) TraceGraph_ELBO? I can see how this might depend on the enumeration option.

If z is not a discrete enumerated variable, you’ll need to implement a guide that can infer z from partial observations, i.e. of x2 but not x1. Then you can simply trace the guide and replay the model.
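
For concreteness, here is a minimal sketch of that trace-and-replay pattern. It assumes the guide takes the x2 observation and samples z from its learned approximate posterior (unlike the empty guide above); the helper name and the Monte Carlo averaging are illustrative only:

import pyro.poutine as poutine

def estimate_p_x1_given_x2(x1_value, x2_value, num_samples=1000):
    count = 0
    for _ in range(num_samples):
        # Run the guide conditioned only on x2 and record its sample of z.
        guide_trace = poutine.trace(guide).get_trace(x2=x2_value)
        # Replay the model with the guide's z; x1 is then sampled given that z.
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)
        ).get_trace(x2=x2_value)
        count += int(model_trace.nodes["x1"]["value"] == x1_value)
    return count / num_samples

# e.g. estimate_p_x1_given_x2(torch.tensor(1), torch.tensor(1))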
