NumPyro svi.run vs svi.update

Hi. I was running SVI for my model in NumPyro to compare its speed against the Pyro version. With svi.run(random.PRNGKey(seed), num_steps, data) I get roughly a 5x speed-up over Pyro; however, when I call svi.update(svi_state, data) in a loop for the same number of steps, the runtime is about the same as Pyro's. Is there a difference between .run and .update for an equal number of steps in NumPyro? The losses they return look the same, but the runtimes are very different.

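You need to JIT-compile your svi.update to get similar speed (e.g. this example uses fori_loop). svi.run compiles the update step for you, whereas a plain Python loop over an uncompiled svi.update does not.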

Oh thanks! Does that work for Pyro as well? Just by replacing the ELBO with its corresponding JIT version?

Probably not. Unfortunately, the PyTorch JIT doesn't usually offer much of a speed-up, at least for complex use cases like Pyro. That said, PyTorch is always evolving, and you may observe speed-ups depending on the precise use case.

Yeah, I checked JitTraceEnum_ELBO and didn't see much improvement. Thanks so much.
