Speed up SVI?

I’m trying to set up a demo notebook on Google Colab (K80 GPU), but SVI is very slow. I’m using an AutoMultivariateNormal guide with the Adam optimizer
and loss=Trace_ELBO(num_particles=1) for 5_000 steps. At the current rate it looks like it will take about 6 hours.

Is there some way to speed up the optimization?


Your question is a bit vague. Some general suggestions:

  • use subsampling if the dataset is large
  • use a diagonal normal guide (AutoDiagonalNormal) if the number of latent dimensions is large
  • use a GPU/TPU if the dataset is large

Hello @fehiepsi
Well, I confess that my question is a bit vague.
The dataset is rather small (500 points), the latent dimension is on the order of 10, and I use a K80 GPU on Google Colab.

Is the progress_bar a bottleneck?

Re progress bar: I don’t think so. For 5000 steps it would add maybe a second in total. I’m not sure what is costing you 6 hours; it’s pretty strange to me…

It is strange. Could you share the Colab notebook? Unless model_spl is very slow to run, SVI shouldn’t be this slow.

Sure, I can share the code, but I don’t know how to do so. I can send you the notebook: it has one pip install/git clone step (for the code we develop with colleagues), and the “data” are generated on the fly.

Are you still interested in having a look at my notebook? Thanks

Yeah sure, you can use https://gist.github.com/ or https://colab.research.google.com/ for the code, or you can just share some of it below. It’s very strange that SVI is this slow with small data and a small latent dimension. The first thing I would try is to remove numpyro, use concrete values for the latent variables, and measure the speed of that model_spl function on its own. If it is slow, then the slowness of SVI is expected.
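The check suggested above can be sketched in plain JAX. Since `model_spl` is not shown, `forward` below is a hypothetical stand-in (a sum of Gaussian bumps evaluated at 500 points) and `theta` holds concrete values in place of the ~10 sampled latents; the timing helper is also illustrative. Each SVI step evaluates the model and its gradient, so if one call here takes seconds, SVI steps will too.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the deterministic part of the model (model_spl is
# not shown in the thread). theta plays the role of the ~10 latent values.
def forward(theta, x):
    centers = jnp.linspace(-1.0, 1.0, theta.shape[0])
    basis = jnp.exp(-((x[:, None] - centers[None, :]) ** 2) / 0.1)
    return basis @ theta


def time_fn(fn, *args, n=100):
    fn(*args).block_until_ready()  # warm-up call (compiles jitted functions)
    start = time.perf_counter()
    for _ in range(n):
        out = fn(*args)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n


theta = jnp.ones(10)  # concrete values instead of sampled latents
x = jnp.linspace(-1.0, 1.0, 500)

t_eager = time_fn(forward, theta, x)
t_jit = time_fn(jax.jit(forward), theta, x)
print(f"eager: {t_eager * 1e3:.3f} ms, jit: {t_jit * 1e3:.3f} ms")
```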

@campagne Your model took 15 s to generate data on CPU, so I expect one SVI step to take about that long. So in 1 hour, you will get about 240 SVI steps. I would suggest improving the speed of your model before using SVI. As far as I can see, this is not a problem with SVI or the autoguides.
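The estimate above is simple arithmetic, assuming (as in the reply) that one SVI step costs about as much as one 15 s model evaluation:

```python
seconds_per_step = 15.0  # assumed cost of one SVI step, from the CPU timing above
steps_per_hour = 3600 / seconds_per_step
hours_for_5000 = 5_000 * seconds_per_step / 3600
print(steps_per_hour, round(hours_for_5000, 1))
```

At that rate, the 5,000 steps from the question would need on the order of a day, which is why speeding up the model comes first.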

Hmm, OK. I guess we are not experienced enough yet to take full advantage of JAX/JIT. Thanks.
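A common reason a JAX model runs slowly is a Python loop over data points; whether `model_spl` does this is an assumption. The sketch below uses a hypothetical per-point function `f_point` and replaces the loop with `jax.vmap` under `jax.jit`, so the work becomes one batched, compiled operation. Under `jit`, a Python loop would instead be unrolled into 500 copies of the computation, which also makes compilation slow; `jax.lax.scan` or `jax.lax.fori_loop` are the usual alternatives when iterations depend on each other.

```python
import jax
import jax.numpy as jnp


def f_point(theta, xi):
    # Hypothetical per-data-point computation.
    return jnp.sum(theta * jnp.sin(xi * jnp.arange(theta.shape[0])))


def f_loop(theta, x):
    # Python loop: 500 separate small computations, each dispatched individually.
    return jnp.stack([f_point(theta, xi) for xi in x])


# Vectorised: map over x (axis 0), broadcast theta, compile once.
f_vmap = jax.jit(jax.vmap(f_point, in_axes=(None, 0)))

theta = jnp.linspace(0.0, 1.0, 10)
x = jnp.linspace(-1.0, 1.0, 500)
assert jnp.allclose(f_loop(theta, x), f_vmap(theta, x), atol=1e-5)
```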