Hello,
I’m trying to set up a demo notebook on Google Colab (K80), but SVI is very slow. I use an AutoMultivariateNormal guide with the Adam optimizer and loss=Trace_ELBO(num_particles=1) for 5_000 steps. It looks like it will take 6 hours or so.
Is there some way to speed up the optimization?
Thanks
Your question is a bit vague. Some general suggestions:
- use subsampling if the data is large
- use a diagonal normal guide if the number of latent dimensions is large
- use a GPU/TPU if the data is large
Hello @fehiepsi
Well, I confess that my question is a bit vague.
The dataset is rather small (500 points), the latent dimension is of the order of 10, and I use a K80 GPU on Google Colab.
Is the progress_bar a bottleneck?
Re progress bar: I don’t think so. Over 5000 steps it might cost just 1 second. I’m not sure what is costing you 6 hours. Pretty strange to me…
It is strange. Could you share the colab? Unless model_spl has very slow code to run, SVI shouldn’t be this slow.
Sure, I can share the code, but I don’t know how to do so. I can send you the notebook; there is one pip/git clone (the code we develop with colleagues), and the “data” are generated on the fly.
Are you still interested in having a look at my notebook? Thanks
Yeah sure, you can use https://gist.github.com/ or https://colab.research.google.com/ for the code, or you can just share some of the code below. It’s very strange that SVI is so slow with small data and a small dimension. The first thing I would try is to remove numpyro, use concrete values for the latent variables, and check the speed of the model_spl function. If that is slow, then the slowness of SVI is expected.
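The check described above could look roughly like the sketch below. The body and signature of model_spl here are a hypothetical stand-in (the real function is in the notebook), and the concrete latent values are arbitrary; the point is only to show how to time the function with and without jax.jit.

```python
import time

import jax
import jax.numpy as jnp


def model_spl(theta, x):
    # Hypothetical stand-in for the real model function:
    # maps 10 latent values and 500 inputs to 500 predictions.
    return jnp.sin(x[:, None] * theta).sum(axis=-1)


theta = jnp.linspace(0.5, 1.5, 10)  # concrete values for the latent variables
x = jnp.linspace(0.0, 1.0, 500)


def time_call(fn, *args, repeats=5):
    # Warm-up call; for a jitted function this also pays the compilation cost.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches work asynchronously, so block until the result is ready.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


eager_s = time_call(model_spl, theta, x)
jit_s = time_call(jax.jit(model_spl), theta, x)
print(f"eager: {eager_s:.6f} s per call, jitted: {jit_s:.6f} s per call")
```

If a single call of the real function takes seconds even after jitting, each SVI step will be at least that slow, whatever guide or optimizer is used.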
@campagne Your model took 15s to generate data on CPU, so I expect one SVI step to take around that time. So in 1 hour, you will get about 240 SVI steps. I would suggest improving the speed of your model first before using SVI. As far as I can see, this is not a problem with SVI or the auto guides.
Hum, ok. I guess we are not experienced enough to take full advantage of JAX/JIT. Thanks.