Is Pyro the right tool for general graphical models?


Hi everyone,

I am currently interested in performing variational inference on large-scale real-valued graphical models (dynamic Bayesian networks, for instance, with tens of thousands of nodes). While I am impressed by the breadth and depth of the Pyro library, I wonder whether it is the right tool for the job.

Indeed, the graphical models discussed in the tutorials (mixtures, HMMs, neural networks) have simple “layered” structures that work well with vectorized plate statements. For general graphs, I feel like the only way is to pyro.sample each variable by hand from its parents in a long Python for loop, thus missing out on the vectorized tensor operations that make the PyTorch backend fast.

Does this problem ring a bell for anyone?
Thanks in advance


You are right to worry that, with a very large number of unstructured nodes (with no opportunities for plate structure and the like), the PyTorch/Python overhead might be considerable. In particular, PyTorch is not optimized for dealing with large numbers of small tensors.
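A rough way to see this per-operation overhead: the sketch below computes the same matrix-vector product once as V separate small dot products and once as a single matmul, and prints both timings. The size V = 2000 and the random data are arbitrary choices for illustration; the actual speed ratio depends on hardware and PyTorch version, so no numbers are claimed here.

```python
import time
import torch

torch.manual_seed(0)
V = 2000  # arbitrary number of nodes in one slice
W = torch.randn(V, V) / V ** 0.5
x = torch.randn(V)

# Many small tensors: one dot product per node, each a separate PyTorch op.
start = time.perf_counter()
y_loop = torch.stack([W[i] @ x for i in range(V)])
t_loop = time.perf_counter() - start

# One large op doing the same arithmetic.
start = time.perf_counter()
y_vec = W @ x
t_vec = time.perf_counter() - start

print(f"per-node loop: {t_loop:.4f}s, single matmul: {t_vec:.4f}s")
```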

We are also working on a NumPy/JAX backend, which could be considerably faster for problems of your sort. That said, depending on what exactly you have in mind, variational inference for such large Bayesian networks may run into more general (not framework-specific) algorithmic problems.


Thanks for your answer!
Indeed, if I want to enjoy Pyro’s superpowers I have to find a way to cheat and vectorize my code.

For those interested in similar problems, since I am dealing with a dynamic Bayesian network, I think it may work to vectorize the state of every node in a given time slice, and hand-code the transition from t to t+1 (I got the idea from the tutorial on deep Markov models). In the transition code, I can then transform the edge structure of my conditional probability distributions into matrix operations.
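One way to turn the edge structure into tensor operations, sketched under assumptions: if the edges between slices t and t+1 are stored as three hypothetical tensors (`parent`, `child`, `weight`), the linear part of every conditional mean can be computed with a single `index_add` call, without materializing a dense V × V weight matrix (which for V = 10,000 would already hold 10^8 entries). The dense product and the per-edge loop are included only to check that all three give the same result.

```python
import torch

torch.manual_seed(0)
V = 5  # hypothetical number of nodes per slice

# Hypothetical edge list: edge k goes from node parent[k] at time t
# to node child[k] at time t+1, with weight weight[k].
parent = torch.tensor([0, 1, 1, 3, 4])
child = torch.tensor([1, 1, 2, 4, 0])
weight = torch.tensor([0.5, -1.0, 2.0, 0.3, 0.7])
x_t = torch.randn(V)  # state of slice t

# Vectorized: gather each edge's parent value, scale it, scatter-add into its child.
loc = torch.zeros(V).index_add(0, child, weight * x_t[parent])

# Check 1: dense (V, V) matrix, only feasible for small V.
W = torch.zeros(V, V)
W[child, parent] = weight
loc_dense = W @ x_t

# Check 2: naive per-edge Python loop.
loc_loop = torch.zeros(V)
for j, i, w in zip(parent.tolist(), child.tolist(), weight.tolist()):
    loc_loop[i] += w * x_t[j]
```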

You’ll find sample (pseudo-)code below for linear Gaussian conditionals. I’m not yet sure it scales well; maybe I will have to add amortization to it.

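import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

# Placeholder problem data (hypothetical values): T time slices, V nodes per slice.
T, V = 5, 4
edge_weights = torch.randn(T - 1, V, V)  # edge_weights[t][i, j]: weight of edge j(t) -> i(t+1)
mus = torch.zeros(T, V)                  # per-node offsets
sigmas = torch.ones(T, V)                # per-node noise scales

class Transition(nn.Module):
    """Linear Gaussian conditional: z[t+1] ~ Normal(W_t z[t] + mu[t+1], sigma[t+1])."""
    def __init__(self, t):
        super().__init__()
        self.weights = edge_weights[t]
        self.mu = mus[t + 1]
        self.sigma = sigmas[t + 1]

    def forward(self, x):
        # missing edges are zeros in self.weights, so one matrix-vector product
        # combines every parent of every node in slice t+1
        return torch.mv(self.weights, x) + self.mu, self.sigma

transitions = [Transition(t) for t in range(T - 1)]

def vectorized_DBN_model():
    # slice 0: nodes are independent, so they share one plate
    with pyro.plate("time_0", V):
        z_t = pyro.sample("z_0", dist.Normal(loc=mus[0], scale=sigmas[0]))
    zs = [z_t]
    for t in pyro.markov(range(T - 1)):
        loc, scale = transitions[t](z_t)
        # given slice t, the nodes of slice t+1 are conditionally independent
        with pyro.plate("time_{}".format(t + 1), V):
            z_t = pyro.sample("z_{}".format(t + 1), dist.Normal(loc=loc, scale=scale))
        zs.append(z_t)
    return torch.stack(zs)  # shape (T, V)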