Parallelizing [exact] inference

I’m rather new to Pyro, so I’m still digesting the codebase. While I know exact inference isn’t ideal for computational efficiency, I’m using that (for now) because I’m reproducing some prior research and want to stick as closely as I can to their process.

I’ve been using the Search class specified in examples/rsa/search_inference.py. Given that I need to draw a few hundred thousand (possibly a few million) samples for exact inference, I was wondering whether there’s a way to parallelize the search process.

My standard tool for parallelization is dask, but given some of its quirks (“all tasks must be futures”), I don’t see a way to do it without significantly modifying a codebase I don’t fully understand.

So far, multiprocessing.Queue complains that it can’t pickle the EscapeMessenger. I’ve also tried multiprocess.Queue, but I seem to get stuck in a lock. For folks who have parallelized their inference process, do you have any recommendations on next steps I could take?

Hi @jmuchovej, I’d recommend trying to parallelize by vectorizing if possible. Often you can rewrite a probabilistic program with dynamic control flow to an equivalent program with static control flow and masking. That may require a custom implementation of Search.
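To make that concrete, here is a minimal sketch (a toy model, not the RSA code): the dynamic version branches in Python on a sampled value, while the static version computes both branch values and selects between them with torch.where, so it keeps working when z carries a batch dimension.

```python
import torch
import pyro
import pyro.distributions as dist


def model_dynamic():
    # dynamic control flow: which branch runs depends on the sampled value,
    # so this version can't be batched as written
    z = pyro.sample("z", dist.Bernoulli(0.5))
    if z.item() == 1.0:
        return pyro.sample("x", dist.Normal(1.0, 1.0))
    return pyro.sample("x", dist.Normal(-1.0, 1.0))


def model_static():
    # static control flow: both branch values are computed as tensors and
    # torch.where masks between them, so z (and hence x) can be batched
    z = pyro.sample("z", dist.Bernoulli(0.5))
    loc = torch.where(z.bool(), torch.tensor(1.0), torch.tensor(-1.0))
    return pyro.sample("x", dist.Normal(loc, 1.0))
```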

:wave: @fritzo, thanks for replying! :slight_smile:

In non-Pyro contexts, I’d have a better idea of what to do here. :sweat_smile: (I’m building an RSA model, but the world state is 500 × 300, so everything balloons because of that.)

Looking at the docs on pyro.sample and pyro.factor (the main Pyro functions I use at the moment), I don’t see a way to take N pyro.sample draws at once (short of a list comprehension). Even then, I don’t quite see how that would achieve a speed-up, since the main bottleneck is higher up the speaker/listener stack: each speaker takes around 30s to compute the world state + utterance combinations.
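For concreteness, here’s the contrast as I understand it (a toy sketch with made-up site names, not my actual model): a list comprehension creates N separate sample sites, whereas pyro.plate gives one site with a batch dimension.

```python
import pyro
import pyro.distributions as dist


def draw_listcomp(n):
    # n separate sample sites: n Python-level calls, nothing is batched
    return [pyro.sample(f"x_{i}", dist.Normal(0.0, 1.0)) for i in range(n)]


def draw_plated(n):
    # one sample site with a batch dimension: a single call returning a
    # tensor of shape (n,), so downstream math can stay vectorized
    with pyro.plate("particles", n):
        return pyro.sample("x", dist.Normal(0.0, 1.0))
```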

As I mentioned above, the EscapeMessenger is deemed unpicklable by multiprocessing. (The exact error is AttributeError: Can't pickle local object 'EscapeMessenger._pyro_sample.<locals>.cont'.)

The main difference between examples/rsa/generics.py and this project is that the state space is significantly larger. (The vocabulary is about the same size, from what I can tell.)

I was thinking of multiprocessing or dask because I have at least 500 × 300 × 3 × 11 (≈5 million) combinations to compute. (Caching those computations helps, but I must still compute hundreds of thousands of values just at the speaker level, and even more at the top-level listener.) The 500 × 300 is the world state, 3 is the number of acceptable distance margins, and 11 is the number of phrases describing the world. There are three levels of RSA: a literal listener over states/margins, a speaker over the phrases and the literal listener, then a top-level [pragmatic] listener over states/margins and the speaker.

@jmuchovej the implementation of Search used in the Pyro RSA examples is rather wasteful and inefficient and is probably not a great foundation for new research. When I wrote it a couple of years ago my only goal was to make the models in those examples resemble their webPPL originals as closely as possible, which unfortunately came at the expense of scalability beyond the toy settings in the originals.

I would recommend rewriting your model in terms of parallelizable tensor operations (e.g. replacing if-statement sequences with tensor indexing and string literals with integers; see the tensor shape and enumeration tutorials for more details and tips) and reimplementing Search. It’s unlikely that I’ll be able to make these changes myself anytime soon, since they’re pretty far removed from my current day job, but if you’re interested in contributing I’d be happy to help you with getting started and dealing with the more technically difficult bits of inference.
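As a rough sketch of the “strings to integers” part (with a made-up three-word vocabulary, not the one from generics.py): encode the utterance meanings in a tensor once, then replace if/elif chains on string utterances with a single batched lookup.

```python
import torch

UTTERANCES = ["none", "some", "all"]            # hypothetical vocabulary
UTT_ID = {u: i for i, u in enumerate(UTTERANCES)}

# meaning[u, s] = 1.0 iff utterance u is true of world state s (3 toy states)
meaning = torch.tensor([
    [1.0, 0.0, 0.0],   # "none"
    [0.0, 1.0, 1.0],   # "some"
    [0.0, 0.0, 1.0],   # "all"
])


def literal_meaning(utt_ids: torch.Tensor) -> torch.Tensor:
    # a batched tensor lookup replaces `if utterance == "none": ... elif ...`
    return meaning[utt_ids]


print(literal_meaning(torch.tensor([UTT_ID["some"], UTT_ID["all"]])))
```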

Alternatively, you could try using the original webPPL example code, which should produce identical results. This will be faster than the current Pyro code, especially if you want to experiment with deeper levels of RSA recursion, but much less scalable than the previous suggestion as the model and dataset sizes grow.

If for some reason neither of these suggestions is feasible and you really need to use the Pyro code as-is, then to work around the pickling error you could try using a nonstandard pickle replacement like cloudpickle that supports pickling closures.
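Something along these lines might work (a rough sketch; run_search is a hypothetical stand-in for whatever callable drives Search over your model): serialize the closure-bearing object with cloudpickle yourself, ship the resulting bytes through multiprocessing, and deserialize on the worker side.

```python
import multiprocessing as mp

import cloudpickle


def _worker(payload: bytes):
    # rebuild the closure-bearing callable inside the worker process
    run_search = cloudpickle.loads(payload)
    return run_search()


if __name__ == "__main__":
    # cloudpickle can serialize closures/lambdas that the stdlib pickle
    # rejects; plain bytes then pass through multiprocessing without trouble
    payload = cloudpickle.dumps(lambda: "stand-in for Search over the model")
    with mp.Pool(4) as pool:
        print(pool.map(_worker, [payload] * 4))
```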


:wave: @eb8680_2 Thanks for the reply! :slight_smile:

I’m not too sure what you mean by contributing (e.g. adding code to Pyro, examples to the docs, or…?), but I’d definitely appreciate some direction. (I’m more inclined to contribute example code but might be able to contribute Pyro code as well.) I’m not too keen on using webPPL, for a variety of reasons.

I’ve read both the tensor shape and enumeration tutorials, but I’m still digesting them. Based on my current understanding, though, the advice in both breaks down when handed recursive models. I’m also not entirely sure how to go about reimplementing Search, as the error I’ve been getting appears to be triggered by poutine.queue's internals (its call into the EscapeMessenger).

I’m not too sure what you mean by contributing

Sorry, yes, I meant updating Pyro’s examples/rsa/generics.py and examples/rsa/search_inference.py so other users interested in this class of models can benefit from your work.

Based on my current understanding, though, the advice in both breaks down once handed recursive models

The recursion depth in all of the RSA examples in examples/rsa/ is static and does not depend on the output of a pyro.sample statement, so that shouldn’t be an issue.

If you’re up for it, I suggest starting with a pull request in which you fork generics.py into a new file generics_vectorized.py and systematically replace every Python control flow statement in the model functions there with tensor indexing or torch.where operations and every non-tensor value with a tensor. Making these changes is necessary for compatibility with any future vectorized version of Search.
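For instance, the mechanical rewrite on a hypothetical snippet (not taken from generics.py) might look like this; the point is that once the branch is a torch.where over tensors, the same line works unchanged when the condition carries a whole batch.

```python
import torch

# Before (Python control flow over bare floats):
#     if threshold_exceeded:
#         prob = 0.9
#     else:
#         prob = 0.1

# After: the branch becomes torch.where and the floats become tensors
threshold_exceeded = torch.tensor([True, False, True])
prob = torch.where(threshold_exceeded, torch.tensor(0.9), torch.tensor(0.1))
print(prob)  # tensor([0.9000, 0.1000, 0.9000])
```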

Feel free to open an incomplete draft PR if you get stuck and need help.


:wave: Sorry for the delayed pushes; I got sidetracked with some other things. :sweat_smile:

So, I’ve opened a draft PR. (I’ve been working out of a notebook, but my progress is also in a similarly named Python file.) I’ve [naïvely] tweaked HashingMarginal to convert the 12-element Tensor of sample sites into individual sites. However, pyro.sample(..., wings_prior) doesn’t appear to produce a 12-element Tensor when called from listener0, whereas structured_prior produces the relevant 12-element Tensor when run over beta_bins.

From what I can tell, it seems like config_enumerate isn’t registering in a recursive way. (But it’s also likely I don’t fully understand how config_enumerate works.)

(FYI: I’m down to continue the conversation on the PR if that makes more sense.)

Thanks for getting started! I may be kind of slow for the next several days. Let’s discuss the details of your code on the PR.