Hi, I haven’t run your code but your model seems fine. I would suggest vectorizing over observations and using indexing instead of Python conditionals so that it is compatible with Pyro’s enumeration machinery for discrete variables:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

@pyro.infer.config_enumerate
def model(N):
    # Conditional probability tables as learnable parameters (initialization elided).
    probs_a = pyro.param("probs_a", ...)  # shape (num_a,)
    probs_b = pyro.param("probs_b", ...)  # shape (num_a, num_b)
    probs_c = pyro.param("probs_c", ...)  # shape (num_a, num_c)
    with pyro.plate("data", N):
        a = pyro.sample("a", dist.Categorical(probs=probs_a))
        # Vindex makes the CPT lookups broadcast correctly under enumeration.
        b = pyro.sample("b", dist.Categorical(probs=Vindex(probs_b)[..., a, :]))
        c = pyro.sample("c", dist.Categorical(probs=Vindex(probs_c)[..., a, :]))
    return a, b, c
```
I’ve used `pyro.param`s for the conditional probability tables to keep the code short and avoid needing a guide, but you could be Bayesian about them as in your original model (see the sketch after the training loop below). You can use SVI with any ELBO for parameter learning, since none of the variables here are latent:
```python
# data_a, data_b, data_c are your observed tensors, each of shape (N,).
conditioned_model = pyro.condition(model, data={"a": data_a, "b": data_b, "c": data_c})
optim = pyro.optim.SGD({"lr": 0.1})
# Note SVI takes (model, guide, optim, loss); the guide is trivial here.
svi = pyro.infer.SVI(conditioned_model, lambda N: None, optim,
                     pyro.infer.TraceEnum_ELBO())
for step in range(num_steps):  # e.g. num_steps = 1000
    svi.step(N)
```
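For reference, here is a minimal sketch of what the Bayesian variant could look like, assuming binary variables and Dirichlet priors on the CPTs (the name `bayesian_model` and the uniform concentrations are just illustrative; your original model may differ):

```python
@pyro.infer.config_enumerate
def bayesian_model(N):
    # Dirichlet priors over the CPTs instead of point-estimated pyro.params.
    probs_a = pyro.sample("probs_a", dist.Dirichlet(torch.ones(2)))
    probs_b = pyro.sample("probs_b", dist.Dirichlet(torch.ones(2, 2)).to_event(1))
    probs_c = pyro.sample("probs_c", dist.Dirichlet(torch.ones(2, 2)).to_event(1))
    with pyro.plate("data", N):
        a = pyro.sample("a", dist.Categorical(probs=probs_a))
        b = pyro.sample("b", dist.Categorical(probs=Vindex(probs_b)[..., a, :]))
        c = pyro.sample("c", dist.Categorical(probs=Vindex(probs_c)[..., a, :]))
    return a, b, c
```

Since the CPTs are then latent, you'd replace the trivial `lambda N: None` guide with a guide over them, e.g. `pyro.infer.autoguide.AutoDelta` applied to the conditioned model.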
You can perform inference (i.e. computing p(A | B = 0)) with `pyro.infer.TraceEnum_ELBO.compute_marginals`:
```python
# Condition on a single observation b = 0 (conditioned values must be tensors).
conditioned_model = pyro.condition(model, data={"b": torch.tensor(0)})
marginal_a = pyro.infer.TraceEnum_ELBO().compute_marginals(
    conditioned_model, lambda N: None, 1)["a"]  # N = 1
argmaxp = torch.argmax(marginal_a.probs.squeeze())
```
You might also find it easier or more intuitive to work directly with the conditional probability tables using `pyro.ops.contract.einsum` for inference (this is what `TraceEnum_ELBO` does under the hood):
```python
from pyro.ops.contract import einsum

# logp_a, logp_b, logp_c are the log-space CPTs, e.g. pyro.param("probs_b").log().
logp_b_0 = dist.Categorical(logits=logp_b).log_prob(torch.tensor(0))  # shape (num_a,)
# "a,a,ac->a" sums out c, leaving an unnormalized log-marginal over a;
# pyro.ops.contract.einsum returns one tensor per output in the equation.
logp_a_marginal, = einsum("a,a,ac->a", logp_a, logp_b_0, logp_c,
                          backend="pyro.ops.einsum.torch_log")
argmaxlp = torch.argmax(logp_a_marginal)
```
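As a quick sanity check (assuming the `logp_*` tables come from the same parameters used by `compute_marginals` above), normalizing the einsum result should reproduce `marginal_a.probs`:

```python
# Normalize the unnormalized log-marginal and compare with compute_marginals.
probs_from_einsum = (logp_a_marginal - logp_a_marginal.logsumexp(-1)).exp()
assert torch.allclose(probs_from_einsum, marginal_a.probs.squeeze(), atol=1e-5)
```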