NumPyro vs. Pyro, SVI in Discrete Latent Variable Models

Hi. NumPyro is commonly said to give a large speed-up over Pyro for MCMC methods.

In my own experience, I also got a very large speed-up for SVI with NumPyro on a large model with discrete latent variables. However, I have also seen cases where there is no significant speed difference between the two.

So my question is: for SVI, is this speed-up due to better enumeration in NumPyro, or to other reasons? Unfortunately, I cannot share my model right now.

Thanks in advance for your answers!

As far as I know, the two backends should behave similarly, so the difference most likely comes from JAX vs. PyTorch.
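
If you want to check where the time goes without sharing the model, one option is to time the first step separately from the later ones, since with JAX the first call typically includes JIT compilation. Below is a minimal sketch using only the standard library. `time_steps` and `dummy_step` are hypothetical names, not part of Pyro or NumPyro; `dummy_step` is just a stand-in that you would replace with a closure around one SVI update in each framework.

```python
import time


def time_steps(step_fn, n_steps=100, warmup=1):
    """Time a zero-argument callable, reporting warm-up cost separately.

    Returns (warmup_seconds, mean_seconds_per_step). The warm-up calls
    absorb one-off costs such as JIT compilation or lazy initialization.
    """
    start = time.perf_counter()
    for _ in range(warmup):
        step_fn()
    warmup_seconds = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    mean_seconds = (time.perf_counter() - start) / n_steps
    return warmup_seconds, mean_seconds


# Hypothetical stand-in for one SVI update; swap in your own step here.
def dummy_step():
    return sum(i * i for i in range(1000))


warmup_seconds, mean_seconds = time_steps(dummy_step, n_steps=50)
print(f"warm-up: {warmup_seconds:.6f} s, per step: {mean_seconds:.6f} s")
```

One caveat for the JAX side: dispatch is asynchronous, so the step you time should wait for its result (for example by calling `.block_until_ready()` on a returned array), otherwise the per-step number can look better than it really is. If the per-step times are close and only the warm-up differs, or the other way round, that should help narrow down whether the gap comes from compilation or from the steady-state computation.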