How Can I Speed Up MAP Estimation?


I am porting a Stan model to NumPyro and trying to compute the MAP estimate of the parameters. I am following the docstring example for Minimize:

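    from jax import random
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.infer.autoguide import AutoDelta
    from numpyro.optim import Minimize

    optimizer = Minimize(method="BFGS")
    guide = AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), ...)
    optimal_state, loss = svi.update(init_state, ...)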

The results match those of the Stan model, but the whole process takes about 7 seconds to run, which seems slow to me because Stan does it in about 0.5 seconds. When I profile the code with cProfile, I see that 5.8 seconds are spent running jax/scipy/optimize/. I am not sure whether the slowness is due to limitations of JAX or to something in my model definition that makes things slower.

Are there any tools or tips you would recommend for profiling NumPyro models to identify bottlenecks?
Are there any common pitfalls in model definition that decrease performance?
Would the optimization be faster if I wrote the guide myself instead of using AutoDelta?

Thanks for the help!

In my opinion, the model definition or the guide is unlikely to be a significant bottleneck in your inference. Several factors affect the speed difference:

  • the number of BFGS iterations differs (it is affected by floating-point precision, tolerance parameters, step size, …)
  • the BFGS implementations differ
  • the compilation time of the JAX program (which likely accounts for a large portion of those 5.8 s)
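One rough way to see how much of the time is compilation is to time the first call of a jitted function (trace, compile, and run) against a second call with inputs of the same shape and dtype (cached executable, run only). The sketch below is an assumption-laden stand-in: `neg_log_density` is a hypothetical toy objective, and it calls `jax.scipy.optimize.minimize` directly rather than going through NumPyro. `block_until_ready()` is needed because JAX dispatches work asynchronously, so without it the timer could stop before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

y = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # hypothetical data


def neg_log_density(theta):
    # Hypothetical objective: Normal likelihood with unknown mean and log-scale.
    mu, log_sigma = theta[0], theta[1]
    sigma = jnp.exp(log_sigma)
    return jnp.sum(0.5 * ((y - mu) / sigma) ** 2 + log_sigma)


@jax.jit
def fit(x0):
    # The whole BFGS run is compiled into one executable per input shape/dtype.
    return minimize(neg_log_density, x0, method="BFGS").x


x0 = jnp.zeros(2)

start = time.perf_counter()
x_first = fit(x0).block_until_ready()  # trace + compile + run
t_first = time.perf_counter() - start

start = time.perf_counter()
x_second = fit(x0 + 0.1).block_until_ready()  # same shape/dtype: run only
t_second = time.perf_counter() - start

print(f"first call: {t_first:.3f}s, second call: {t_second:.3f}s")
```

If the second timing is small compared with the first, most of the cost is compilation rather than the optimization itself.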

@fehiepsi Thank you for the reply! Yes, you are right that compilation takes up most of the runtime! I need to look into why it is triggered 127 times; I would have expected it to happen only once.
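A possible explanation (a guess, since the model code isn't shown): JAX compiles once per distinct function object and input shape/dtype. Re-creating a jitted function or a closure inside a loop forces a fresh trace and compile on every call, and un-jitted code is dispatched op by op, with each operation compiled separately. Setting `jax.config.update("jax_log_compiles", True)` logs every compilation, which helps pinpoint where the 127 come from. The sketch below uses a hypothetical `trace_count` counter, which only increments while JAX is tracing, to contrast reusing one jitted function with wrapping a fresh lambda on each call.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def loss(theta, y):
    global trace_count
    # Python side effect: runs only while JAX traces, not on cached calls.
    trace_count += 1
    return jnp.sum((y - theta) ** 2)


y = jnp.arange(5.0)

# Pattern 1: build the jitted function once and reuse it.
loss_jit = jax.jit(loss)
for step in range(3):
    loss_jit(jnp.float32(step), y)
reused = trace_count

# Pattern 2: wrap a fresh lambda in jax.jit on every call (e.g. inside a loop).
trace_count = 0
for step in range(3):
    jax.jit(lambda t, data: loss(t, data))(jnp.float32(step), y)
rewrapped = trace_count

print(reused, rewrapped)  # → 1 3
```

The same reasoning applies to `svi.update`: jitting it once and reusing the result, with inputs of a fixed shape, should keep the number of compilations small.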