Is Pyro the right tool for general graphical models?

Hi everyone,

I am currently interested in performing variational inference on large-scale real-valued graphical models (for instance, dynamic Bayesian networks with tens of thousands of nodes). While I am impressed by the breadth and depth of the Pyro library, I wonder if it is the right tool for the job.

Indeed, the graphical models discussed in the tutorials (mixtures, HMMs, neural networks) have simple “layered” structures that work well with vectorized plate statements. For general graphs, I feel like the only option is to call pyro.sample on each variable by hand, conditioning on its parents, in an endless for loop, thus missing out on the benefits of the PyTorch backend.

Does this problem ring a bell for anyone?
Thanks in advance
Giom


you are right to worry that, with a very large number of unstructured nodes (with no opportunities for plate structure and the like), the pytorch/python overhead might be considerable. in particular pytorch is not optimized for dealing with large numbers of small tensors.

we are also working on a numpy/jax backend which could be considerably faster for problems of your sort. although, depending on what exactly you have in mind, variational inference for such large bayesian networks may run into more general (not framework-specific) algorithmic problems.

Thanks for your answer!
Indeed, if I want to enjoy Pyro’s superpowers I have to find a way to cheat and vectorize my code.

For those interested in similar problems: since I am dealing with a dynamic Bayesian network, I think it may work to vectorize over all the nodes within a given time slice and hand-code the transition from t to t+1 (I got the idea from the tutorial on deep Markov models). In the transition code, I can then turn the edge structure of my conditional probability distributions into matrix operations.

You’ll find sample (pseudo-)code below for linear Gaussian conditionals. I’m not yet sure it scales well; maybe I will have to add amortization to it.

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

# Assumed to be defined elsewhere:
#   T: number of time slices, V: number of nodes per slice
#   mus, sigmas: per-slice location/scale tensors of shape (V,)
#   edge_weights: per-transition (V, V) matrices encoding the graph's edges

class Transition(nn.Module):
    """Linear Gaussian transition from time slice t to t+1."""

    def __init__(self, t):
        super().__init__()
        self.weights = edge_weights[t]
        self.sigma = sigmas[t + 1]
        self.mu = mus[t + 1]

    def forward(self, x):
        # the edge structure enters through the matrix-vector product
        return torch.mv(self.weights, x) + self.mu, self.sigma

transitions = [Transition(t) for t in range(T - 1)]

def vectorized_DBN_model():
    z = torch.empty((T, V))
    # each plate ranges over the V nodes of one time slice,
    # so there is a single vectorized sample site per slice
    with pyro.plate("time_0", V):
        z[0] = pyro.sample("z_0", dist.Normal(loc=mus[0], scale=sigmas[0]))
    for t in pyro.markov(range(T - 1)):
        loc, scale = transitions[t](z[t])
        with pyro.plate("time_{}".format(t + 1), V):
            z[t + 1] = pyro.sample(
                "z_{}".format(t + 1),
                dist.Normal(loc=loc, scale=scale),
            )
    return z
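
For completeness, here is a rough, untested sketch of how inference could then be wired up with an autoguide and SVI (the observed sites that would tie the model to actual data are omitted here):

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# sketch only: a mean-field guide over the z sites; real data would enter
# the model through pyro.sample(..., obs=...) statements, omitted above
guide = AutoNormal(vectorized_DBN_model)
svi = SVI(vectorized_DBN_model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step()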

Bye!
Giom

Hi again,

As I said above, the generative model can be written without trouble. However, I now want to train the parameters of my DBN by declaring them as pyro.param. While this works fine for mu and sigma, the problem lies with weights: it is an N x N sparse matrix whose sparsity pattern is dictated by the adjacency matrix of my graph, and I don’t want to train N^2 parameters if I can avoid it.

Is there an equivalent of poutine.mask for parameters, so that I can specify which of them to update while keeping vectorized code?
Otherwise, I was thinking of declaring weights as a full matrix variable and putting a Laplace prior on it to encourage sparsity, but that is much less satisfying…
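
For concreteness, the full-matrix variant I have in mind would look roughly like the snippet below (untested; N and laplace_scale are placeholders, and weights would become a latent site instead of a pyro.param):

import torch
import pyro
import pyro.distributions as dist

# sketch: a dense N x N weight matrix with an i.i.d. Laplace (sparsity-inducing)
# prior, declared as a single latent site
weights = pyro.sample(
    "weights",
    dist.Laplace(torch.zeros(N, N), laplace_scale * torch.ones(N, N)).to_event(2),
)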

Thanks in advance
Giom

the simplest thing to do is probably to use a fully parameterized NxN matrix but then, after each gradient step, to explicitly zero out unwanted entries. something like

pyro.param("myparam").data = my_zero_one_mask * pyro.param("myparam").data

(although it might be more complicated due to parameter constraints?)
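
in a training loop that might look roughly like this (untested sketch; it assumes the parameter is unconstrained, my_zero_one_mask has the same shape, and svi / num_steps are whatever you already use):

import torch
import pyro

for step in range(num_steps):
    svi.step()
    with torch.no_grad():
        # project the parameter back onto the graph structure after each update
        pyro.param("myparam").data.mul_(my_zero_one_mask)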

Thanks for the suggestion!
This actually gave me another idea: zero out the unwanted parameters every time I use them, instead of projecting after the gradient step. I think this is actually better, but I’m not sure, since I do not fully understand the mechanics of automatic differentiation.
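
Concretely, I mean something like the snippet below inside the transition (a sketch; adjacency_mask is a hypothetical 0/1 tensor with the same shape as the weight matrix, and N, mu, x are as in my earlier pseudo-code). Since the masked entries never contribute to the output, their gradients are exactly zero and they should never be updated:

import torch
import pyro

# sketch: mask the parameter at every use instead of projecting after each step;
# masked entries receive zero gradient, so they stay at their initial value
weights = pyro.param("weights", 0.01 * torch.randn(N, N)) * adjacency_mask
loc = torch.mv(weights, x) + mu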