I am porting a Stan model to NumPyro and trying to compute the MAP estimate of the parameters. I am following the docstring example for `Minimize`:
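```python
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Minimize

optimizer = Minimize(method="BFGS")
guide = AutoDelta(model)  # point-mass (MAP) guide
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
init_state = svi.init(random.PRNGKey(0), ...)
optimal_state, loss = svi.update(init_state, ...)  # one call runs BFGS to convergence
```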
The results match those of the Stan model, but the whole process takes about 7 seconds, which seems slow to me because Stan does it in about 0.5 seconds. When I profile the code with cProfile, I see that 5.8 seconds are spent in `jax/scipy/optimize/minimize.py`. I am not sure whether the slowness is due to limitations of JAX or to something in my model definition that makes things slower.
Are there any tools or tips you would recommend for profiling NumPyro models to identify bottlenecks?
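cProfile attributes time to the Python functions on the call stack, so the 5.8 seconds charged to `jax/scipy/optimize/minimize.py` can include JAX tracing and XLA compilation, not only the optimization itself. One way to tell these apart is a minimal sketch like the one below, assuming a recent JAX version with the ahead-of-time `jax.jit(f).lower(...).compile()` API and a hypothetical quadratic objective standing in for the model's negative log density. It times tracing/lowering, compilation, and execution as separate stages.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def neg_log_density(theta):
    # Hypothetical stand-in for a model's negative log joint density.
    return 0.5 * jnp.sum((theta - 3.0) ** 2)


def fit(x0):
    # Full BFGS run; jax.scipy.optimize.minimize is traceable under jit.
    return minimize(neg_log_density, x0, method="BFGS").x


x0 = jnp.zeros(3)

t0 = time.perf_counter()
lowered = jax.jit(fit).lower(x0)  # trace Python code and lower it for XLA
t1 = time.perf_counter()
compiled = lowered.compile()  # XLA compilation
t2 = time.perf_counter()
theta_hat = compiled(x0).block_until_ready()  # execution only
t3 = time.perf_counter()

print(f"trace/lower: {t1 - t0:.3f}s  compile: {t2 - t1:.3f}s  run: {t3 - t2:.3f}s")
```

If tracing and compilation dominate, repeated fits with the same input shapes can reuse one jitted function, since `jax.jit` caches the compiled program and only the first call pays that cost.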
Are there any common pitfalls in model definition that decrease performance?
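A commonly cited pitfall is a per-observation Python loop: JAX unrolls it at trace time, so the traced program grows with the number of data points, which tends to lengthen tracing and compilation. The sketch below uses a hypothetical log-likelihood in plain JAX and counts traced operations with `jax.make_jaxpr` for a loop versus a vectorized version. In a NumPyro model, the vectorized form corresponds to one `numpyro.sample(..., obs=y)` inside `numpyro.plate` instead of one sample statement per observation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical observed data.
y = jnp.linspace(-1.0, 1.0, 50)


def loglik_loop(mu):
    # Python loop: unrolled at trace time into 50 copies of the same ops.
    total = 0.0
    for i in range(y.shape[0]):
        total = total + norm.logpdf(y[i], mu, 1.0)
    return total


def loglik_vectorized(mu):
    # One vectorized logpdf over all observations, then a single reduction.
    return jnp.sum(norm.logpdf(y, mu, 1.0))


n_loop = len(jax.make_jaxpr(loglik_loop)(0.0).jaxpr.eqns)
n_vec = len(jax.make_jaxpr(loglik_vectorized)(0.0).jaxpr.eqns)
print(f"traced equations: loop={n_loop}, vectorized={n_vec}")
```

Both functions compute the same value; only the size of the traced program differs.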
Would the optimization be faster if I wrote the guide myself as opposed to using
Thanks for the help!