Hi,
I am experiencing a problem on a Linux machine with a V100 GPU.
numpyro version '0.8.0'
jax version '0.2.24'
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0] :: Anaconda, Inc. on linux
I installed a fresh copy of Anaconda3, then created a conda environment with Python 3.8.
Then I installed NumPyro with conda install -c conda-forge numpyro
The problem is that this NumPyro install pulls in jaxlib 0.1.73, which does not see the GPU.
As a workaround, I first had to run pip uninstall jax jaxlib
and then
pip install --upgrade jax jaxlib==0.1.72+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
So I do not know whether anyone else sees the same behavior with the new jaxlib 0.1.73 (on CUDA 11.3).
Here is the complete CUDA install information:
- cuda_version: '11-3'
- cuda_release: '11.3.0-1'
- cuda_drivers_version: '465.19.01-1'
- cudnn_version: '8.2.0.53'
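
For anyone who wants to compare, here is a minimal sketch of how one might check whether the installed jaxlib sees the GPU. It only relies on jax.devices() and the platform attribute of each device; the helper name visible_platforms is my own, and I am assuming a CUDA device is reported with platform "gpu" (or possibly "cuda", so both are checked).

```python
def visible_platforms():
    """Return the sorted list of platform names JAX reports, or [] if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return []
    # Each device exposes a platform string such as "cpu" or "gpu".
    return sorted({d.platform for d in jax.devices()})


platforms = visible_platforms()
print("JAX platforms:", platforms)
# Assumption: a working CUDA build reports "gpu" (checked alongside "cuda" to be safe).
print("GPU visible:", any(p in ("gpu", "cuda") for p in platforms))
```

With a CPU-only jaxlib build I would expect only "cpu" in the list, which would match the "does not see the GPU" symptom described above.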