Hello!
I am porting a Stan model to NumPyro and trying to compute the MAP estimate of its parameters. I am following the docstring example for Minimize:
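from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Minimize

optimizer = Minimize(method="BFGS")
guide = AutoDelta(model)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
init_state = svi.init(random.PRNGKey(0), ...)
optimal_state, loss = svi.update(init_state, ...)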
The results match those of the Stan model, but the whole process takes about 7 seconds to run, which seems slow to me because Stan does it in about 0.5 seconds. When I profile the code with cProfile, I see that 5.8 seconds are spent in jax/scipy/optimize/minimize.py. I am not sure whether the slowness is due to limitations of JAX or to something in my model definition.
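A minimal sketch of this kind of cProfile measurement, assuming the SVI calls are wrapped in a single function. The name `run_map_estimate` and its body are hypothetical placeholders for illustration, not code from my model:

```python
import cProfile
import io
import pstats


def run_map_estimate():
    # Hypothetical stand-in for the svi.init / svi.update calls;
    # replace the body with the real workload.
    return sum(i * i for i in range(100_000))


# Profile only the region of interest.
profiler = cProfile.Profile()
profiler.enable()
result = run_map_estimate()
profiler.disable()

# Sort by cumulative time so the most expensive call paths come first.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(10)  # limit the report to the top 10 entries
report = stream.getvalue()
print(report)
```

Sorting by cumulative time is what brings an entry such as minimize.py to the top of the report, since it includes the time spent in everything that file calls.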
Are there any tools/tips you have for profiling numpyro models to identify bottlenecks?
Are there any common pitfalls in model definition that decrease performance?
Would the optimization be faster if I wrote the guide myself as opposed to using AutoDelta?
Thanks for the help!