Call for Contributors: Probabilistic Programming & Bayesian Inference in ONNX

Hi everyone,

I wanted to share an initiative that’s getting started and invite discussion and participation from the Pyro and NumPyro communities.

We’re working within the ONNX ecosystem on a proposal to support probabilistic programming and Bayesian inference as first-class workloads in ONNX. The aim is to define a standardized set of ONNX operators and runtime semantics that allow probabilistic models—including Pyro and NumPyro models—to be exported, executed, and optimized portably across frameworks and hardware.

Pyro and NumPyro are central reference points for this work, particularly around:

  • Effect handlers and tracing semantics

  • Vectorized MCMC (HMC/NUTS)

  • JAX-based execution and SPMD-friendly model structure

  • Stateless, explicit RNG handling
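To make the first two bullets concrete for readers less familiar with the pattern: Pyro/NumPyro models are plain functions whose `sample` statements are intercepted by a stack of effect handlers, and a tracing handler records every site the model visits. The following is a toy, self-contained sketch of that mechanism (not Pyro's actual API; `sample`, `Trace`, and `normal_logpdf` here are illustrative names) — it is this kind of trace that an exporter would walk to recover model structure.

```python
import math

# Toy effect-handler sketch (NOT Pyro's real API). `sample` statements
# notify whatever handlers are active; a tracing handler records each
# site's name, value, and log-density.

_HANDLERS = []  # active handler stack

def sample(name, log_prob_fn, value):
    """A sample site: compute its log-density and notify active handlers."""
    site = {"name": name, "value": value, "log_prob": log_prob_fn(value)}
    for h in reversed(_HANDLERS):
        h.process(site)
    return site["value"]

class Trace:
    """Context manager that records every site the model executes."""
    def __init__(self):
        self.sites = {}
    def __enter__(self):
        _HANDLERS.append(self)
        return self
    def __exit__(self, *exc):
        _HANDLERS.pop()
    def process(self, site):
        self.sites[site["name"]] = site

def normal_logpdf(loc, scale):
    """Log-density of Normal(loc, scale) as a function of x."""
    const = math.log(scale * math.sqrt(2.0 * math.pi))
    return lambda x: -0.5 * ((x - loc) / scale) ** 2 - const

def model():
    z = sample("z", normal_logpdf(0.0, 1.0), 0.5)
    sample("x", normal_logpdf(z, 1.0), 1.0)

with Trace() as tr:
    model()

log_joint = sum(s["log_prob"] for s in tr.sites.values())
```

Running the model under the tracing handler yields the full set of sites and, from them, the log-joint — the object an ONNX export path would need to capture.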

What we’re trying to build

At a high level, we’re exploring how ONNX could act as a compute and deployment backend for probabilistic programs, in the same way it already does for neural networks.

Concrete areas of focus include:

  • A probabilistic operator domain in ONNX:

    • Distributions, log-probability evaluation, factors

    • Bijectors / transforms for constrained parameters

  • Stateless, splittable RNG semantics, inspired by JAX and NumPyro

  • Special mathematical functions needed for stable log-density computation

  • Inference operators and building blocks:

    • Laplace, Pathfinder, INLA

    • Metropolis, Gibbs, Slice

    • HMC and NUTS (with finite-state-machine (FSM) style control flow)

    • Sequential Monte Carlo (SMC)

  • ONNX Runtime integration and execution-provider semantics

  • Exporter paths for probabilistic programming frameworks:

    • Pyro and NumPyro as priority Python targets

    • Alignment with JAX-based tooling and vectorized execution
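On the RNG bullet above, the key property borrowed from JAX/NumPyro is that a PRNG key is pure data: `split` derives independent child keys as a pure function, and drawing a sample never mutates the key. Here is a minimal, hypothetical sketch of those semantics using a hash as a stand-in for a counter-based generator — the real proposal would specify a concrete algorithm (e.g. threefry), not SHA-256.

```python
import hashlib

# Hypothetical stand-in for counter-based RNG semantics (JAX uses
# threefry; SHA-256 here is only for illustration). Keys are immutable
# byte strings; splitting and drawing are pure functions.

def split(key: bytes, n: int = 2) -> list:
    """Derive n independent child keys from a parent key, deterministically."""
    return [hashlib.sha256(key + i.to_bytes(4, "big")).digest() for i in range(n)]

def uniform(key: bytes) -> float:
    """Map a key to a deterministic draw in [0, 1) without mutating state."""
    raw = int.from_bytes(hashlib.sha256(key + b"uniform").digest()[:8], "big")
    return raw / 2**64

root = b"seed-0"
k1, k2 = split(root)
# Same key -> same draw; sibling keys -> distinct, independent streams.
```

Because everything is a pure function of the key, the same graph replayed on any backend reproduces the same draws — which is exactly what makes this style of RNG attractive for a portable operator spec.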

The goal is not to reimplement Pyro or NumPyro inside ONNX, but to:

  • Capture the log-joint semantics and inference primitives in a portable IR

  • Allow Pyro/NumPyro models to target ONNX for deployment and acceleration

  • Enable hardware-agnostic execution (CPU, GPU, edge, accelerators) without bespoke kernels per framework
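To illustrate what "capture the log-joint in a portable IR" could mean (this is an illustrative toy, not actual ONNX or a proposed opset), here is the log-joint of a simple normal-normal model flattened into a small operator graph plus a generic interpreter. Op names like `NormalLogPdf` are placeholders for whatever a probabilistic domain would standardize.

```python
import math

# Illustrative toy IR (NOT real ONNX): nodes are (op_type, output, inputs),
# evaluated in topological order against an environment of named values.

def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))

OPS = {
    "NormalLogPdf": normal_logpdf,
    "Add": lambda a, b: a + b,
}

# Log-joint of: z ~ Normal(0, 1); x ~ Normal(z, 1)
GRAPH = [
    ("NormalLogPdf", "lp_z", ["z", "zero", "one"]),
    ("NormalLogPdf", "lp_x", ["x", "z", "one"]),
    ("Add", "log_joint", ["lp_z", "lp_x"]),
]

def run(graph, env):
    """Evaluate each node, binding its output name in the environment."""
    env = dict(env)
    for op, out, ins in graph:
        env[out] = OPS[op](*(env[i] for i in ins))
    return env

env = run(GRAPH, {"z": 0.5, "x": 1.0, "zero": 0.0, "one": 1.0})
```

Once the log-joint is a flat graph like this, any backend that implements the op catalog can evaluate (and differentiate) it without knowing anything about Pyro or NumPyro.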

Why this might matter for Pyro / NumPyro users

  • A portable representation for probabilistic models beyond Python/JAX runtimes

  • Cleaner paths to production deployment for MCMC-based models

  • Shared operator semantics across Pyro, NumPyro, Stan, PyMC, TFP, etc.

  • Better alignment with accelerator-friendly execution (vectorized chains, SPMD)

  • An opportunity to influence how ONNX handles randomness, control flow, and probabilistic structure—rather than adapting to it later

Open questions we’re actively discussing

  • How much of an inference algorithm (e.g. NUTS) should live inside ONNX vs. in the PPL

  • How to represent dynamic control flow without breaking portability

  • How to preserve NumPyro/JAX RNG semantics exactly

  • What a “minimum viable” probabilistic opset should look like

  • How exporters should map Pyro traces or NumPyro JAX graphs into ONNX

How to get involved

We’re forming working groups around:

  • RNG semantics and operator specification

  • Distribution, bijector, and special-function catalogs

  • Inference operators and control-flow patterns

  • Exporter design for Pyro and NumPyro

If you’re interested in contributing, reviewing specs, or just sharing feedback from a Pyro/NumPyro perspective, we’d really appreciate your input. Even critical feedback about what won’t work is extremely valuable at this stage.

Feel free to reply here, ask questions, or reach out directly if you’d like to be looped into follow-up discussions.

Best,
Brian


Hi Brian - please keep me in the loop for follow-up discussions. As an avid Pyro user who will soon be required to engage almost exclusively with ONNX, I definitely have a stake in enabling compatibility between the two.


Okay, awesome! Can you message me? Just send me your LinkedIn profile and an email address.