PyTorch, JAX, Funsor, etc.: I am a bit lost ;)

Hi,
This is a somewhat open-ended question, but I was wondering what the rationale is behind supporting different backends and how to decide, as a user, which way to go.

I am starting to integrate probabilistic models into existing production systems that support the research of scientists at my institution. Although the available books and tutorials made me think that PyMC or Stan were the more mature ecosystems, I started using Pyro because we use PyTorch for other parts of the system.

Since we are just starting with PPLs, we may still have time to review our choices, but I lack advice. Does anyone on this forum have suggestions on which criteria to take into account when choosing between Pyro and NumPyro, and on what the advancement of Funsor implies for that choice?

Our main constraint is that we have to make choices for our platform, and we will have to maintain the code produced by PhD students and faculty several years down the road. A maintenance nightmare is exactly what I would like to avoid.

Sorry for not being more specific, and thanks for Pyro and all the work you are putting into it!

Just my opinion: for a production system, it is better to stick with Pyro because your organization is already familiar with PyTorch, so if bugs come up, your team can fix them faster. You can consider NumPyro if you need speed in MCMC methods or if you find some JAX features helpful for your research programs. I think you can switch between Pyro and NumPyro without much difficulty because the APIs for the main features are similar. Funsor supports both backends, so it does not matter much for this choice.

@garjola, I feel your pain, maintenance is hard work! I basically agree with @fehiepsi: Pyro is older and has more features, documentation, and error handling, whereas NumPyro has some fancy new inference algorithms that are very fast. Regarding JAX vs. PyTorch, JAX changes more often than PyTorch, so code in Pyro+PyTorch will be easier to maintain than code in NumPyro+JAX as versions evolve.

You probably don’t need to worry about Funsor: it is an intermediate layer where we put algorithms that are common to Pyro and NumPyro. Funsor sits in between like Pyro-Funsor-PyTorch or NumPyro-Funsor-JAX. As an intermediate layer, Funsor’s machinery should be mostly hidden from users of NumPyro and Pyro.

Our rationale for maintaining two systems is two-fold: (1) JAX is way faster for MCMC, so in some applications it is worth the maintenance cost; (2) JAX and PyTorch are both popular libraries, and we on the Pyro team would like to hedge our bets as to which will be more popular in 2 or 5 years. By keeping Pyro and NumPyro mostly compatible, we’re hoping it will be possible to port models between frameworks (possible, but still hard work).

Hi @fehiepsi and @fritzo,

Thanks for your feedback. It is reassuring to have these elements. I have already experienced the speed difference for MCMC between the two backends, and I now better understand the appeal. Porting the code between Pyro and NumPyro was easy in my case.

If MCMC is the main difference between the two, our choice is easier, since we have 2 main use cases at our institution: the people working with “big data” and DL (pure PyTorch) + GP (GPyTorch), who are now introducing some PP (and therefore VI rather than MCMC), and the people of the “small data world”, who are more interested in MCMC and seldom in VI.

So actually, the API similarity between Pyro and NumPyro is a big win for us. Having MCMC inference in Pyro as fast as in NumPyro would be great, but I understand that this is not possible even with PyTorch’s JIT compiler.

Thanks again for your time!