Moving MCMC from CPU to GPU

Hi all,

I’ve read a few posts on the forum about how to use a GPU for MCMC (“Transfer SVI, NUTS and MCMC to GPU (Cuda)”, “How to move MCMC run on GPU to CPU” and “Training on single GPU”), but I still have a few questions about how to get the most out of numpyro. There is also this blog post comparing MCMC sampling methods on GPU; although its model is built in pymc, it uses numpyro samplers.

  • The blog post compares “pymc_jax_gpu_parallel : PyMC with JAX backend (numpyro sampler) and GPU, running chains in sequence” and “pymc_jax_gpu_vectorized : PyMC with JAX backend (numpyro sampler) and GPU, running chains in parallel”. All models ran four chains. I found the terminology slightly confusing, since pymc_jax_gpu_parallel runs chains in sequence. Could you clarify whether the sequential method runs all the chains on the GPU at the same time? (I have reached out to the author but have not received a reply.)
  • As you can see in the plots in the post, the vectorised method is clearly superior for large models. @fehiepsi has said in a few posts that the vectorised method is good if you draw a lot of chains, but this is only four chains. Can you explain why the vectorised method is best here, and whether it is always better, or on what sort of models sequential might perform better?
  • I have access to 2 GPUs. Does numpyro efficiently make use of both GPUs for MCMC sampling with numpyro.set_platform("gpu"), or is there something else that needs to be done?
  • The final comments at the bottom of this and this thread ask how to run inference on GPU and analysis in arviz on CPU. @fehiepsi did you find a definitive way of doing this, or is it best to save the output and load it for arviz in a new CPU session?

Cheers,

Theo

I think it is better to benchmark your specific problem yourself. numpyro provides three chain methods: sequential, vectorized, and parallel. Vectorized and sequential are useful for a single device; parallel is useful for multiple devices. If you have 4 GPUs, you can just use the parallel method. If you have 1 GPU, you can use either the sequential or the vectorized method. The speed depends on the problem (if chains are not mixing well, the vectorized method can be slow, because vectorized chains advance in lockstep and each step lasts as long as the slowest chain's trajectory), and it is easy to switch between these methods via the chain_method argument. I’m not sure why using 1 GPU with the vectorized method is faster than using 4 GPUs with the parallel method. Probably the communication cost between devices is large, or the vectorized method works well for some expensive computations in the model.

Regarding arviz, you can move all samples to CPU using jax.device_get and perform analysis on those CPU samples.
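A small sketch of what jax.device_get does with a pytree of samples. The dict below is a hypothetical stand-in for the output of mcmc.get_samples(group_by_chain=True); device_get keeps the dict structure and replaces each JAX array with a NumPy array in host memory, so later analysis no longer touches the GPU.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for mcmc.get_samples(group_by_chain=True):
# a dict of JAX arrays living on the default device (e.g. a GPU)
samples = {
    "alpha": jnp.zeros((4, 1000)),
    "beta": jnp.ones((4, 1000, 3)),
}

# Same dict structure, but every leaf is now a NumPy array on the host
host_samples = jax.device_get(samples)
print(type(host_samples["alpha"]), host_samples["beta"].shape)
```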

My feeling is that, given the author had 1 GPU in the blog post, pymc_jax_gpu_parallel actually refers to sequential and pymc_jax_gpu_vectorized refers to vectorised. NumPyro itself falls back to drawing chains sequentially when there are fewer devices than chains, so perhaps the pymc.sampling_jax module didn’t surface the warning that chain_method="parallel" can't run the chains in parallel on 1 GPU.
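One way to check this on your own machine is to ask JAX how many devices it sees before picking chain_method. As far as I know, when chain_method="parallel" is requested with fewer devices than chains, NumPyro's MCMC emits a warning and draws the chains sequentially. The describe_chain_layout helper below is hypothetical; it only wraps jax.local_device_count() and jax.devices().

```python
import jax


def describe_chain_layout(num_chains):
    """Hypothetical helper: report the devices JAX sees and whether
    chain_method="parallel" could place one chain on each device."""
    n_devices = jax.local_device_count()
    platform = jax.devices()[0].platform  # e.g. "cpu" or "gpu"
    return {
        "platform": platform,
        "devices": n_devices,
        "parallel_possible": n_devices >= num_chains,
    }


print(describe_chain_layout(num_chains=4))
```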

So I think we’re comparing gpu_sequential and gpu_vectorized on 1 GPU (not 4). I guess the vectorized method might be faster simply because it runs all the chains at the same time? Unless there is another reason.

With 2 GPUs, it’s a little trickier to decide between the vectorised method and the parallel method with 2 chains run sequentially on each GPU. I guess there’s no other way but to try it out.

So is this something like:

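import arviz as az
import jax
...
# group_by_chain=True keeps the (chain, draw, ...) shape that ArviZ expects
fit = mcmc.get_samples(group_by_chain=True)
fit = jax.device_get(fit)  # copy the samples from the GPU into host NumPy arrays
posterior = az.from_dict(posterior=fit)  # this happens on CPU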

I assume I don’t need a numpyro.set_platform("cpu") call somewhere in between?

I actually prefer az.from_dict(...) to az.from_numpyro(mcmc) for this use case, because I only want to save the posterior samples as an xarray-like object.
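A sketch of the from_dict route, assuming ArviZ 0.x (where az.from_dict takes a posterior= keyword) and using random NumPy arrays as hypothetical stand-ins for the host-side samples. The resulting InferenceData contains only a posterior group, unlike az.from_numpyro(mcmc), which also collects sample statistics such as divergences.

```python
import arviz as az
import numpy as np

# Hypothetical host-side samples shaped (chain, draw, ...), as returned by
# jax.device_get(mcmc.get_samples(group_by_chain=True))
rng = np.random.default_rng(0)
samples = {
    "alpha": rng.normal(size=(4, 500)),
    "beta": rng.normal(size=(4, 500, 3)),
}

# Posterior-only InferenceData backed by xarray; no sample_stats group
idata = az.from_dict(posterior=samples)
print(idata.posterior.sizes)
print(az.summary(idata, var_names=["alpha"]))
```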

I think so. That’s what I understand from the jax.device_get docs. 🙂