How Can I Speed Up MAP Estimation?

Hello!

I am porting a Stan model to NumPyro and trying to compute the MAP estimate of the parameters. I am following the docstring example for `Minimize`:

    from jax import random
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.infer.autoguide import AutoDelta
    from numpyro.optim import Minimize

    optimizer = Minimize(method="BFGS")
    guide = AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), ...)
    optimal_state, loss = svi.update(init_state, ...)

The results match those of the Stan model, but the whole process takes about 7 seconds, which seems slow: Stan does it in about 0.5 seconds. When I profile the code with cProfile, I see that 5.8 of those seconds are spent in jax/scipy/optimize/minimize.py. I am not sure whether the slowness is due to limitations of JAX or whether something in my model definition is making things slower.

Are there any tools/tips you have for profiling numpyro models to identify bottlenecks?
Are there any common pitfalls in model definition that decrease performance?
Would the optimization be faster if I wrote the guide myself as opposed to using AutoDelta?

Thanks for the help!


In my opinion, it is unlikely that the model definition or guide adds many bottlenecks to your inference. There are several factors that affect the speed:

  • the number of BFGS iterations may differ (it is affected by floating-point precision, the tolerance parameters, the step size, …)
  • the BFGS implementations themselves differ
  • the compilation time of the JAX program (which likely accounts for a large portion of that 5.8 s)
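To see how much of the runtime is compilation rather than computation, one generic approach (a sketch, not specific to this model) is to time the first call to a jitted function against a subsequent call, which reuses the cached executable:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def loss(x):
    # Stand-in for an expensive objective; any jitted function works.
    return jnp.sum(x ** 2)


x = jnp.ones(1000)

t0 = time.perf_counter()
loss(x).block_until_ready()  # first call: trace + compile + run
compile_and_run = time.perf_counter() - t0

t0 = time.perf_counter()
loss(x).block_until_ready()  # second call: cached executable, run only
run_only = time.perf_counter() - t0
```

The gap between the two timings is (roughly) the one-time compilation cost; `block_until_ready()` is needed because JAX dispatches asynchronously, so without it you would only time the dispatch.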

@fehiepsi Thank you for the reply! Yes, you are right that compilation takes the largest part of the runtime! I need to look into why it is being called 127 times; I would have expected it to be called only once.
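For reference, a common reason `jit` compilation fires repeatedly is that input shapes or dtypes change between calls, which forces a retrace and recompile. A minimal sketch (the function `f` here is hypothetical, not part of the model above) that counts retraces via a Python side effect, which only runs during tracing:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # executes only when JAX (re)traces the function
    return x * 2.0


f(jnp.ones(3))  # shape (3,): trace + compile
f(jnp.ones(3))  # same shape/dtype: cache hit, no retrace
f(jnp.ones(4))  # new shape (4,): retrace + recompile
```

If the counter climbs on every call, something (array shapes, Python-level arguments, freshly constructed closures) is defeating the compilation cache. Setting `jax.config.update("jax_log_compiles", True)` also logs each compilation, which can help pinpoint which function is being recompiled.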