Pyro trace: how is it used?


#1

Hi Pyro devs, would you mind explaining Pyro programs from the perspective of the trace?

That is, in probabilistic programming systems (PPS) one can think of inference as sampling from the posterior distribution over execution traces arising from stochastic programs constrained to reflect observed data. This implies there must be an explicit representation of the execution trace. In Pyro this seems to be done with the Trace (and TraceMessenger?) objects, but it’s not entirely clear how the trace is invoked and used. For example, importance sampling inference creates a few trace instances (here), but how does the program operate on these trace objects, and how do they relate to a global execution trace?

The main info I’m looking for is (i) where a global execution trace is represented and how Trace objects interact with it, (ii) where and how the trace is modified, e.g. to condition on observed data, and (iii) what criteria and design choices went into the Pyro trace.

Thanks in advance!

Cheers,
Alex


#2

Hi, you’ve already found the Trace object. There is no global execution trace in Pyro, and traces aren’t used to control execution. Instead, Pyro is built on Poutine, a library of composable effect handlers that modify the behavior of Pyro programs.

To trace the execution of a stochastic function in Pyro you can wrap it with poutine.trace: traced_fn = poutine.trace(fn), which returns a callable with the same input and return types as fn. Each time you call traced_fn(*args, **kwargs), this object stores in its trace attribute a new Trace data structure containing all the pyro.sample and pyro.param statements encountered while executing fn(*args, **kwargs). traced_fn.get_trace(*args, **kwargs) is a convenient helper that calls traced_fn with those arguments and returns the resulting Trace object.

To condition a sample site on data, you can use poutine.condition (aliased to pyro.condition): constrained_fn = pyro.condition(fn, data={"x": x_data}). This returns a function with the same input and return types as fn, but with sample site "x" constrained to be equal to x_data. constrained_fn can then be used with poutine.trace as above, where sample site "x" will be marked as constrained. See the language intro tutorial for more details.

You can read more about the design principles of Pyro in the blog post announcing its release.


#3

Thank you for the quick reply @eb8680! These notes are helpful. I was unaware of the aliases defined in poutine/__init__.py. And I had read that post when Pyro was first released, but the design notes make more sense now that I’m familiar with the code.