Pyro for Directed Bayesian Network Inference

Hi, I haven’t run your code, but your model seems fine. I would suggest vectorizing over observations and using tensor indexing instead of Python conditionals, so that the model is compatible with Pyro’s enumeration machinery for discrete variables:

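```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

@pyro.infer.config_enumerate
def model(N):
    # Initial CPT values elided (...); row i of probs_b / probs_c is p(B | A = i) / p(C | A = i).
    probs_a = pyro.param("probs_a", ..., constraint=constraints.simplex)
    probs_b = pyro.param("probs_b", ..., constraint=constraints.simplex)
    probs_c = pyro.param("probs_c", ..., constraint=constraints.simplex)
    with pyro.plate("data", N):
        a = pyro.sample("a", dist.Categorical(probs=probs_a))
        # Select the CPT row for each value of a; this also works when a is enumerated.
        b = pyro.sample("b", dist.Categorical(probs=probs_b[a]))
        c = pyro.sample("c", dist.Categorical(probs=probs_c[a]))
    return a, b, c
```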
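To see what the indexing in the model expresses, here is a plain-Python sketch of ancestral sampling from the same structure (edges A -> B and A -> C): A is drawn first, and its value picks the row of each conditional probability table used to draw B and C. The CPT values and the helper `sample_network` are hypothetical, chosen only for illustration.

```python
import random

# Hypothetical CPTs for binary A, B, C with edges A -> B and A -> C.
# Row a of probs_b / probs_c is the distribution of B / C given A = a.
probs_a = [0.6, 0.4]
probs_b = [[0.7, 0.3], [0.2, 0.8]]
probs_c = [[0.9, 0.1], [0.5, 0.5]]


def sample_network(n, rng):
    """Ancestral sampling: draw A, then draw B and C from the CPT rows selected by A."""
    data = []
    for _ in range(n):
        a = rng.choices(range(len(probs_a)), weights=probs_a)[0]
        b = rng.choices(range(len(probs_b[a])), weights=probs_b[a])[0]
        c = rng.choices(range(len(probs_c[a])), weights=probs_c[a])[0]
        data.append((a, b, c))
    return data


rng = random.Random(0)
samples = sample_network(10_000, rng)
with_a0 = [s for s in samples if s[0] == 0]
# Empirical p(B = 0 | A = 0); should be close to probs_b[0][0] = 0.7.
print(sum(1 for s in with_a0 if s[1] == 0) / len(with_a0))
```

The Pyro model does the same row selection for all N data points at once, which is what lets enumeration replace the sampled value of a with a tensor of all its possible values.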

I’ve used pyro.param for the conditional probability tables to keep the code short and avoid needing a guide, but you could be Bayesian about them as in your original model. Once the model is conditioned on data, none of the variables are latent, so the guide can be empty and you can use SVI with any ELBO for parameter learning:

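```python
pyro.clear_param_store()
# data_a, data_b, data_c: observed integer tensors of length N.
conditioned_model = pyro.condition(model, data={"a": data_a, "b": data_b, "c": data_c})
optim = pyro.optim.SGD({"lr": 0.01})  # pyro.optim.SGD needs an args dict; lr is an example value
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
svi = pyro.infer.SVI(conditioned_model, lambda N: None, elbo, optim)  # empty guide
for step in range(num_steps):
    svi.step(N)
```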
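Because every variable is observed and the tables are plain parameters (no priors), the objective is the log-likelihood, which splits into one term per CPT row; its maximizer is the table of normalized counts. The plain-Python sketch below computes those maximum-likelihood CPTs directly, which can serve as a sanity check on what the SVI loop should approach. The helper `mle_cpts` and the toy dataset are hypothetical, and filling rows for never-observed values of A with a uniform distribution is an arbitrary choice.

```python
from collections import Counter


def mle_cpts(data, k_a, k_b, k_c):
    """Maximum-likelihood CPTs for A -> B, A -> C from fully observed (a, b, c) triples.

    Rows for values of A that never occur are set to uniform (an arbitrary choice;
    the likelihood does not constrain them).
    """
    n = len(data)
    count_a = Counter(a for a, _, _ in data)
    count_ab = Counter((a, b) for a, b, _ in data)
    count_ac = Counter((a, c) for a, _, c in data)

    def row(counts, a, k):
        # Normalized counts of the child given A = a.
        if count_a[a] == 0:
            return [1.0 / k] * k
        return [counts[(a, v)] / count_a[a] for v in range(k)]

    probs_a = [count_a[a] / n for a in range(k_a)]
    probs_b = [row(count_ab, a, k_b) for a in range(k_a)]
    probs_c = [row(count_ac, a, k_c) for a in range(k_a)]
    return probs_a, probs_b, probs_c


# Tiny hypothetical dataset of (a, b, c) observations.
observed = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 1, 1)]
probs_a, probs_b, probs_c = mle_cpts(observed, k_a=2, k_b=2, k_c=2)
print(probs_a)  # [0.75, 0.25]
```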

You can perform inference (e.g. computing p(A | B = 0)) with pyro.infer.TraceEnum_ELBO.compute_marginals:

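```python
# Condition on B = 0 for a single data point (N = 1).
conditioned_model = pyro.condition(model, data={"b": torch.tensor([0])})
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
marginal_a = elbo.compute_marginals(conditioned_model, lambda N: None, 1)["a"]
# marginal_a is the marginal distribution of A given B = 0, with batch shape (1,).
argmaxp = torch.argmax(marginal_a.probs.squeeze())
```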
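The marginal computed here is p(A | B = 0) ∝ p(A) p(B = 0 | A) Σ_c p(C = c | A). The plain-Python sketch below does that enumeration by hand for hypothetical binary CPTs, which is a small-scale check of what the enumeration computes; the tables and the helper `posterior_a_given_b` are illustrative assumptions, not part of Pyro.

```python
# Hypothetical CPTs for binary A, B, C with edges A -> B and A -> C.
p_a = [0.6, 0.4]                        # p(A = a)
p_b_given_a = [[0.7, 0.3], [0.2, 0.8]]  # p(B = b | A = a), row a
p_c_given_a = [[0.9, 0.1], [0.5, 0.5]]  # p(C = c | A = a), row a


def posterior_a_given_b(b_obs):
    """Return p(A | B = b_obs) by summing out C and normalizing."""
    unnorm = []
    for a in range(len(p_a)):
        # C is unobserved, so it is summed out (a CPT row sums to 1).
        sum_c = sum(p_c_given_a[a])
        unnorm.append(p_a[a] * p_b_given_a[a][b_obs] * sum_c)
    z = sum(unnorm)  # p(B = b_obs)
    return [u / z for u in unnorm]


post = posterior_a_given_b(0)
map_a = max(range(len(post)), key=lambda a: post[a])
print(post, map_a)  # approximately [0.84, 0.16] 0
```

Since C is a leaf that is not observed, summing it out contributes a factor of 1 here; it would matter if C had observed descendants.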

You might also find it easier or more intuitive to work directly with the conditional probability tables using pyro.ops.contract.einsum for inference (this is what TraceEnum_ELBO does under the hood):

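```python
import torch
import pyro.distributions as dist
from pyro.ops.contract import einsum

# logp_a: (K_a,), logp_b: (K_a, K_b), logp_c: (K_a, K_c) are log CPTs (rows indexed by a).
# log p(B = 0 | A = a) for every a:
logp_b_0 = dist.Categorical(logits=logp_b).log_prob(torch.tensor(0))
# Sum out C; the result is the unnormalized log p(A = a, B = 0).
# pyro.ops.contract.einsum returns a tuple with one tensor per output.
logp_a_marginal, = einsum("a,a,ac->a", logp_a, logp_b_0, logp_c,
                          backend="pyro.ops.einsum.torch_log")
argmaxlp = torch.argmax(logp_a_marginal)
```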
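With the log backend, the equation "a,a,ac->a" means: add the log factors that share index a, and log-sum-exp over c, which appears only in the inputs. The plain-Python sketch below spells that out for hypothetical binary log-CPTs, then normalizes to get p(A | B = 0). It is a stand-in for illustration, not Pyro's implementation; `logsumexp` and `einsum_a_a_ac_to_a` are hypothetical helpers.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Hypothetical log-CPTs for binary A, B, C with edges A -> B and A -> C.
logp_a = [math.log(0.6), math.log(0.4)]
logp_b = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
logp_c = [[math.log(0.9), math.log(0.1)], [math.log(0.5), math.log(0.5)]]


def einsum_a_a_ac_to_a(x_a, y_a, z_ac):
    """Log-space 'a,a,ac->a': add logs over the shared index a, logsumexp over c."""
    return [x_a[a] + y_a[a] + logsumexp(z_ac[a]) for a in range(len(x_a))]


logp_b_0 = [row[0] for row in logp_b]  # log p(B = 0 | A = a)
logp_a_marginal = einsum_a_a_ac_to_a(logp_a, logp_b_0, logp_c)  # log p(A = a, B = 0)
argmaxlp = max(range(len(logp_a_marginal)), key=lambda a: logp_a_marginal[a])

# Normalizing the unnormalized log-joint gives the posterior p(A | B = 0).
log_z = logsumexp(logp_a_marginal)
posterior = [math.exp(lp - log_z) for lp in logp_a_marginal]
print(argmaxlp, posterior)  # 0, approximately [0.84, 0.16]
```

The argmax does not need the normalization step, since subtracting log p(B = 0) shifts every entry by the same amount.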