Wrapping numpyro for inference in directed PGMs

Hi all! I’m currently working on a library that’s laser-focused on constructing and doing inference on Bayesian networks (see here), and I would love to use NumPyro to leverage the speed benefits of XLA compilation in JAX.

My plan is to write out a Pyro model based on a specification of the PGM in graph-type language (nodes, edges, etc.), and then delegate the inference-related methods to NumPyro.

To begin, I’m only considering discrete models, so I’ve been reading through the various posts on inference using enumeration (config_enumerate, infer_discrete, etc.), but I’m unsure what level of abstraction I’m looking for here:

  • If I have access to conditional probability tables, is using opt_einsum enough for inference by variable elimination? Or can I use one of the above abstractions to do this in an easier way?
  • Would one need to bother creating junction trees in the general case of inference in non-tree-like Bayes nets? Or would I have to somehow manually create a new pyro model based on the junction tree and then delegate to the appropriate inference method?

Thanks for the help! 🙂


Hi @phinate, thanks for your interest in NumPyro as a backend for your library.

If you are only interested in exact inference with discrete variables and no plates, then yes, working directly with opt_einsum is feasible. Using config_enumerate and infer_discrete might be a little easier, and would also let you use plate notation via numpyro.plate and combine variable elimination with NumPyro's other inference algorithms.
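Here is a minimal sketch of that route for a tiny network A → C ← B, assuming a recent NumPyro with funsor installed (for the numpyro.contrib.funsor module); the CPT values and variable names are made up, and MAP decoding (temperature=0) is just one choice for illustration:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete

# Made-up CPTs for binary variables A, B and their child C (axes ordered A, B, C).
p_a = jnp.array([0.6, 0.4])
p_b = jnp.array([0.7, 0.3])
p_c_given_ab = jnp.array(
    [[[0.9, 0.1], [0.4, 0.6]],
     [[0.5, 0.5], [0.2, 0.8]]]
)

@config_enumerate
def model(c_obs=None):
    a = numpyro.sample("A", dist.Categorical(p_a))
    b = numpyro.sample("B", dist.Categorical(p_b))
    # Plain indexing is fine without plates; with plates you'd want Vindex-style indexing.
    numpyro.sample("C", dist.Categorical(p_c_given_ab[a, b]), obs=c_obs)

# Exact MAP assignment of the latents A, B given the evidence C = 1; the
# discrete sites are eliminated by enumeration under the hood.
map_model = infer_discrete(
    model, first_available_dim=-1, temperature=0, rng_key=jax.random.PRNGKey(0)
)
with handlers.trace() as tr:
    map_model(c_obs=jnp.array(1))
print({name: site["value"] for name, site in tr.items() if name in ("A", "B")})
```

With temperature=1 instead of 0, the wrapped model draws the discrete latents from their exact posterior rather than returning the MAP assignment.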

You might also be interested in using our JAX-compatible Funsor library as an intermediate representation corresponding roughly to factor graphs.

Using opt_einsum or NumPyro for variable elimination will implicitly and automatically construct a junction tree. To get an intuition for why this is true, try working out how einsum("a,b,ab->", ...) could be evaluated in terms of np.sum() and elementwise products and how this einsum query corresponds to the marginal likelihood p(C=c) in the Bayesian network p(A) * p(B) * p(C | A, B).
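For example, with made-up CPT values (the numbers here are just placeholders), that query looks like:

```python
import numpy as np
from opt_einsum import contract

# Made-up CPTs for the network p(A) * p(B) * p(C | A, B), all variables binary.
p_a = np.array([0.6, 0.4])        # p(A), indexed by a
p_b = np.array([0.7, 0.3])        # p(B), indexed by b
p_c_given_ab = np.array(          # p(C | A, B), axes ordered (a, b, c)
    [[[0.9, 0.1], [0.4, 0.6]],
     [[0.5, 0.5], [0.2, 0.8]]]
)

c = 1  # the evidence C = c

# Variable elimination as a single contraction: sum_{a,b} p(a) p(b) p(C=c | a, b).
# opt_einsum picks the contraction order, which plays the role of the junction tree.
marginal = contract("a,b,ab->", p_a, p_b, p_c_given_ab[..., c])

# The same query spelled out with elementwise products and np.sum.
by_hand = np.sum(p_a[:, None] * p_b[None, :] * p_c_given_ab[..., c])
assert np.allclose(marginal, by_hand)
```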
