Multi GPU num_samples

I was wondering whether it's possible to split the particles used to estimate the ELBO across multiple GPUs. For example, with num_particles=40 and 4 GPUs, each GPU would estimate the ELBO from 10 particles for each sample in the training set, and the estimates would be aggregated at the end.

I've seen some topics about multi-GPU distributed training that splits a pyro.plate (i.e., the data) across devices, but I'm more interested in increasing the total number of particles to reduce the variance of the estimate.


Unfortunately, I'm not aware of an easy, general way to accomplish this, at least as of this post. If you can provide a bit more context, we might be able to suggest a workaround, e.g. via clever use of torch.nn.DataParallel or JAX's pmap in NumPyro.

Thanks for your response, @eb8680_2. I want to train an SVI instance on a machine with four GPUs. I'm more concerned with reducing variance during training than with speeding up the computation itself. On a single GPU, I can increase the number of particles used to estimate the ELBO (per epoch and per loaded data point) until I max out its memory. However, I would like to increase the total number of particles further, so that all four GPUs take part in estimating the ELBO gradient: for example, copy the same data to all four GPUs, draw 10 different particles on each, and combine the results on GPU0 once all GPUs have finished.
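A minimal sketch of that scheme, assuming a hypothetical toy model with two explicit parameter tensors rather than a real Pyro model and param store (`elbo_particles` and `multi_device_grads` are illustrative names, not Pyro API): each device gets a copy of the parameters and the same data point, draws its own share of particles, and backpropagates locally; the per-device gradients are then moved to the first device and averaged before a single optimizer step. With equal shares, the average of the per-device means equals the mean over all particles, so this is the same estimator as num_particles=40 on one device. The loop visits devices largely one after another, so it mainly buys memory (each GPU holds only its share of the graph) rather than speed; for genuine parallelism, one process per GPU with torch.distributed (or Horovod) all-reducing the gradients is the more standard route.

```python
import torch
from torch.distributions import Normal


def elbo_particles(loc, log_scale, x, eps):
    """Per-particle ELBO for a hypothetical toy model (illustration only).

    Model: z ~ Normal(0, 1), x ~ Normal(z, 1).
    Guide: z ~ Normal(loc, exp(log_scale)), reparameterized via eps.
    """
    scale = log_scale.exp()
    z = loc + scale * eps
    log_p = Normal(0.0, 1.0).log_prob(z) + Normal(z, 1.0).log_prob(x)
    log_q = Normal(loc, scale).log_prob(z)
    return log_p - log_q


def multi_device_grads(loc, log_scale, x, num_particles, devices):
    """Split num_particles evenly across devices; average gradients on devices[0].

    loc and log_scale are leaf tensors on devices[0]. Each device gets a
    detached copy of them plus its own independent noise.
    """
    assert num_particles % len(devices) == 0, "assumes an even split"
    per_device = num_particles // len(devices)
    grads = []
    for dev in devices:
        r_loc = loc.detach().to(dev).requires_grad_()
        r_log_scale = log_scale.detach().to(dev).requires_grad_()
        eps = torch.randn(per_device, device=dev)  # this device's particles
        loss = -elbo_particles(r_loc, r_log_scale, x.to(dev), eps).mean()
        loss.backward()
        grads.append((r_loc.grad.to(devices[0]), r_log_scale.grad.to(devices[0])))
    # Equal shares: the mean of per-device means is the mean over all particles.
    loc.grad = torch.stack([g for g, _ in grads]).mean(dim=0)
    log_scale.grad = torch.stack([g for _, g in grads]).mean(dim=0)


# Use every visible GPU, or four CPU "devices" to try the logic without GPUs.
devices = [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
devices = devices or [torch.device("cpu")] * 4

loc = torch.tensor(0.5, device=devices[0], requires_grad=True)
log_scale = torch.tensor(-1.0, device=devices[0], requires_grad=True)
optimizer = torch.optim.Adam([loc, log_scale], lr=0.01)
x = torch.tensor(2.0)

for step in range(100):
    optimizer.zero_grad()
    multi_device_grads(loc, log_scale, x, num_particles=40, devices=devices)
    optimizer.step()
```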

Although this is much more complicated, my main motivation is this: if we estimate the mean of a Gaussian random variable from N i.i.d. samples, the variance of the estimate should, in theory, shrink by a factor of N (equivalently, its standard deviation shrinks by a factor of √N). Unfortunately, I don't have enough memory on a single GPU to achieve an order-of-magnitude variance reduction.
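The 1/N law can be checked numerically. This is a minimal sketch in plain Python (the function name and trial counts are illustrative choices): draw N i.i.d. standard-normal samples, average them, repeat many times, and measure the variance of the averages. The same law applies to an N-particle ELBO gradient estimate, since it is the mean of N i.i.d. single-particle estimates, so going from 1 to 40 particles should cut the variance by about 40x (the standard deviation by about √40 ≈ 6.3x).

```python
import random
import statistics


def variance_of_sample_mean(n, trials=10000, sigma=1.0, seed=0):
    """Empirical variance of the mean of n i.i.d. Normal(0, sigma^2) draws."""
    rng = random.Random(seed)
    means = [
        statistics.fmean(rng.gauss(0.0, sigma) for _ in range(n))
        for _ in range(trials)
    ]
    return statistics.variance(means)


for n in (1, 10, 40):
    # Theory: Var(mean) = sigma^2 / n, i.e. the standard deviation is sigma / sqrt(n).
    print(f"N={n:2d}  empirical={variance_of_sample_mean(n):.4f}  theory={1.0 / n:.4f}")
```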

In fact, I've noticed that GPU memory usage also grows as I increase the number of particles for the gradient estimate. Is there a way to reduce the memory footprint as num_particles grows?
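One common way to keep memory bounded is to split the particles into chunks and call backward() once per chunk, accumulating gradients before a single optimizer step. Here is a minimal sketch, assuming a hypothetical toy model with explicit parameter tensors rather than a real Pyro model (`elbo_particles` and `accumulate_grads` are illustrative names). Only one chunk's autograd graph is alive at a time, so peak memory follows the chunk size, while the accumulated gradient matches a single vectorized pass over the same particles. On the Pyro side, Trace_ELBO takes `num_particles` and `vectorize_particles` arguments; with vectorization on, all particles are evaluated in one batched pass, so it may be worth checking whether the non-vectorized mode fits the memory budget better.

```python
import torch
from torch.distributions import Normal


def elbo_particles(loc, log_scale, x, eps):
    """Per-particle ELBO for a hypothetical toy model (illustration only).

    Model: z ~ Normal(0, 1), x ~ Normal(z, 1).
    Guide: z ~ Normal(loc, exp(log_scale)), reparameterized via eps.
    """
    scale = log_scale.exp()
    z = loc + scale * eps
    log_p = Normal(0.0, 1.0).log_prob(z) + Normal(z, 1.0).log_prob(x)
    log_q = Normal(loc, scale).log_prob(z)
    return log_p - log_q


def accumulate_grads(params, x, eps, chunk_size):
    """Accumulate the gradient of -mean(ELBO) over chunks of particles.

    Only one chunk's autograd graph is alive at a time, so peak activation
    memory scales with chunk_size instead of the total particle count.
    """
    loc, log_scale = params
    total = eps.shape[0]
    for p in params:
        p.grad = None
    for chunk in eps.split(chunk_size):
        # Dividing by the total count makes the summed chunk gradients equal
        # the gradient of the mean over all particles.
        loss = -elbo_particles(loc, log_scale, x, chunk).sum() / total
        loss.backward()  # releases this chunk's saved activations


torch.manual_seed(0)
loc = torch.tensor(0.5, requires_grad=True)
log_scale = torch.tensor(-1.0, requires_grad=True)
x = torch.tensor(2.0)
eps = torch.randn(400)  # 400 particles, processed 10 at a time

accumulate_grads([loc, log_scale], x, eps, chunk_size=10)
chunked = (loc.grad.clone(), log_scale.grad.clone())

# Reference: the same 400 particles in a single vectorized pass.
loc.grad = None
log_scale.grad = None
(-elbo_particles(loc, log_scale, x, eps).mean()).backward()
print(torch.allclose(chunked[0], loc.grad), torch.allclose(chunked[1], log_scale.grad))
```

The trade-off is time: the chunks run sequentially, so a step with 400 particles in chunks of 10 takes roughly as long as 40 small passes.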

I hope this clarifies things a bit.

Hi @pyrosol,
one approach would be to use Pyro's HorovodOptimizer, as in the Horovod example. Although that example aims to reduce variance due to data subsampling, the same setup can also be adapted to reduce variance due to latent-variable stochasticity (for instance, by giving every worker the same data but its own random particles).

Another, simpler approach is to use more momentum or a lower learning rate. If you find that your gradients have high variance, two possible solutions are to (1) vectorize over many particles, or (2) sequentially compute multiple gradient estimates before making an update. Pyro doesn't implement (2), because the common approach in deep learning is instead to (3) compute a single gradient estimate per step but use high momentum, averaging gradients over many learning steps, each of which changes the parameters only a tiny bit.
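To see why high momentum acts like averaging, consider an idealized setting (an assumption: parameters held fixed, gradient noise i.i.d. with variance σ²). A normalized momentum buffer m_t = β m_{t-1} + (1 - β) g_t then has a steady-state variance satisfying Var(m) = β² Var(m) + (1 - β)² σ², i.e. Var(m) = σ² (1 - β) / (1 + β). For β = 0.9 that is σ²/19, roughly like averaging 19 independent gradient estimates; β = 0.99 gives σ²/199. The plain-Python simulation below (names are illustrative) checks the formula.

```python
import math
import random


def ema_variance(beta, sigma=1.0, steps=200_000, burn_in=1_000, seed=0):
    """Empirical steady-state variance of m_t = beta * m_{t-1} + (1 - beta) * g_t.

    g_t ~ Normal(0, sigma^2) stands in for the noise in per-step gradient
    estimates at a fixed parameter value (an idealization).
    """
    rng = random.Random(seed)
    m = 0.0
    history = []
    for t in range(steps):
        m = beta * m + (1 - beta) * rng.gauss(0.0, sigma)
        if t >= burn_in:  # skip the transient from m_0 = 0
            history.append(m)
    mean = math.fsum(history) / len(history)
    return math.fsum((v - mean) ** 2 for v in history) / len(history)


for beta in (0.0, 0.9, 0.99):
    theory = (1 - beta) / (1 + beta)
    print(f"beta={beta:.2f}  empirical={ema_variance(beta):.4f}  theory={theory:.4f}")
```

In real training the parameters move, so older gradients in the buffer are stale; that is why this averaging only helps when each step changes the parameters a tiny bit. PyTorch's SGD momentum buffer omits the (1 - β) factor by default (dampening=0), which rescales the buffer by 1/(1 - β) but leaves the relative variance reduction unchanged.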

Good luck!

Thanks, @fritzo. I'll try your suggestions and see whether any of them fit my needs.