Hi. I was running SVI on my model with NumPyro to compare its speed against the Pyro version. When I use svi.run(random.PRNGKey(seed), num_steps, data), there is roughly a 5x speed-up over the Pyro version; however, when I use svi.update(svi_state, data) for the same number of steps, there is almost no difference from the Pyro one. So I wanted to know whether .run and .update differ in NumPyro for an equal number of steps. Also, the losses they return look the same, but their runtimes are very different.
You need to JIT-compile your svi.update to get similar speed (e.g. this example uses fori_loop).
Oh thanks! Does that work for Pyro as well, just by replacing the ELBO with its corresponding JIT version?
Probably not. Unfortunately, the PyTorch JIT doesn't usually offer much of a speed-up, at least for complex use cases like Pyro. That said, PyTorch is always evolving, and you may observe speed-ups depending on the precise use case.
Yeah, I checked JitTraceEnum_ELBO and didn't see much improvement. Thanks so much.