Hi all,
I’ve read a few posts on the forum about how to use GPU for MCMC: "Transfer SVI, NUTS and MCMC to GPU (Cuda)", "How to move MCMC run on GPU to CPU" and "Training on single GPU", but there are a few questions I still have on how to get the most out of numpyro. There is also this blog post comparing MCMC sampling methods on GPU, and although the model is built in pymc, it uses numpyro samplers.
- The blog post compares "`pymc_jax_gpu_parallel`: PyMC with JAX backend (numpyro sampler) and GPU, running chains in sequence" and "`pymc_jax_gpu_vectorized`: PyMC with JAX backend (numpyro sampler) and GPU, running chains in parallel". All models ran four chains. I found the terminology slightly confusing, with `pymc_jax_gpu_parallel` running chains in sequence. Could you clarify whether the sequential method runs all the chains on the GPU at the same time or not? (I have reached out to the author but have not received a reply.)
- As you can see in the plots in the post, the vectorised method is clearly superior for large models. @fehiepsi has said in a few posts that the vectorised method is good if you draw a lot of chains, but this is only four chains. Can you explain why the vectorised method is best here, whether it is always better, and on what sort of models sequential might perform better?
- I have access to 2 GPUs. Does numpyro efficiently make use of both GPUs for MCMC sampling using `numpyro.set_platform("gpu")`, or is there something else that needs to be done?
- The final comments at the bottom of this and this ask how to run inference on GPU and analysis in arviz on CPU. @fehiepsi, did you find a definitive way of doing this, or is it best to save the output and load it up for arviz in a new CPU session?
Cheers,
Theo