Pyro is dramatically slower than PyMC3 with Normalizing Flows on stochastic-volatility model inference

Without your PyMC3 code I can’t say for sure, but I don’t think the two times you’re comparing measure the same thing, either in model implementation (the PyMC3 example you linked to uses a single GaussianRandomWalk distribution rather than a Python for loop) or in inference (from your description it sounds like you’re comparing VI in PyMC3 to NeutraHMC in Pyro, and the flows and optimization procedures may differ and use different hyperparameters). Still, I agree the difference is too large.

As a general comment on performance: Pyro models that perform many operations on many small tensors and rely heavily on Python control flow, like the version of the stochastic volatility model you’ve written here, suffer from serious overhead. As tensor sizes grow and the fraction of inference time spent on actual numerical computation (e.g. matrix multiplication) increases, these issues fade away.
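To make this concrete, here’s a minimal sketch in plain PyTorch (not your model; the unit step scale is an illustrative assumption) comparing a Python-loop log-density of a Gaussian random walk, which pays per-step tensor-op overhead at every timestep, against a single batched computation over all increments:

```python
import torch

torch.manual_seed(0)
T = 1000
x = torch.randn(T).cumsum(0)  # a synthetic random-walk trajectory

def logp_loop(x, scale=1.0):
    # T-1 separate small-tensor ops: each iteration pays Python and
    # PyTorch dispatch overhead, and autograd records a long graph.
    lp = torch.tensor(0.0)
    for t in range(1, len(x)):
        lp = lp + torch.distributions.Normal(x[t - 1], scale).log_prob(x[t])
    return lp

def logp_vectorized(x, scale=1.0):
    # One batched op over all T-1 increments: same math, far less overhead.
    return torch.distributions.Normal(x[:-1], scale).log_prob(x[1:]).sum()

# Both compute the same log-density, up to floating-point accumulation order.
assert torch.allclose(logp_loop(x), logp_vectorized(x))
```

The two functions evaluate the same quantity; the difference is purely in how much time is spent inside fast batched kernels versus the Python interpreter and per-op bookkeeping.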

Unfortunately, in our experience these issues largely reflect overhead and performance limitations inherent in the design of PyTorch tensors, autograd, and the JIT, which are not currently optimized for large graphs like those in your time-series model (see e.g. Tensor overhead, slow JIT compilation) or for operations used heavily in certain inference algorithms (e.g. advanced indexing, einsum).

It’s usually possible to work around these issues by vectorizing your model as suggested by @fritzo, and in some cases our parallel-scan-based distribution implementations like GaussianHMM can even provide significant speedups for long time series. For affected users who can’t apply these workarounds, we also encourage trying out JAX and NumPyro, which currently performs much better in this regime thanks to XLA’s optimizations and which will approach inference feature parity with Pyro over time.