NumPyro 0.8.0 requires jaxlib 0.1.73: GPU not detected

Hi,
I am experiencing a problem on a Linux machine with a V100 GPU.

numpyro version '0.8.0'
jax version '0.2.24'

Python 3.8.12 (default, Oct 12 2021, 13:49:34) 
[GCC 7.5.0] :: Anaconda, Inc. on linux

I installed a fresh copy of Anaconda3, then created a conda environment with Python 3.8.

Then I installed NumPyro with: conda install -c conda-forge numpyro

The problem is that NumPyro pulls in jaxlib 0.1.73, which does not see the GPU.
I first had to run pip uninstall jax jaxlib and then:

pip install --upgrade jax jaxlib==0.1.72+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
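To confirm whether a given jaxlib build actually sees the GPU, one option is to ask JAX which devices it has. The sketch below assumes only that jax is importable in the active environment; the helper name gpu_devices is hypothetical. It uses jax.devices() and each device's platform attribute ("gpu" for CUDA devices, "cpu" otherwise).

```python
def gpu_devices():
    """Return the JAX devices whose platform is "gpu" (hypothetical helper).

    Returns an empty list if jax/jaxlib cannot be imported or if only
    CPU devices are visible, e.g. with a CPU-only jaxlib build.
    """
    try:
        import jax
    except ImportError:
        # jax missing, or jax installed without a usable jaxlib
        return []
    # jax.devices() lists the devices of the default backend; with a
    # CPU-only jaxlib these will all report platform "cpu".
    return [d for d in jax.devices() if d.platform == "gpu"]


if __name__ == "__main__":
    found = gpu_devices()
    if found:
        print("JAX sees GPU(s):", found)
    else:
        print("No GPU visible to JAX; jaxlib may be a CPU-only build.")
```

Running this before and after swapping jaxlib should make it clear whether the reinstall changed anything.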

I do not know whether others see the same behavior with the new jaxlib 0.1.73 (on CUDA 11.3). Here is my complete CUDA installation information:

  • cuda_version: ‘11-3’
  • cuda_release: ‘11.3.0-1’
  • cuda_drivers_version: ‘465.19.01-1’
  • cudnn_version: ‘8.2.0.53’

Starting again from a fresh Anaconda3 installation and a new conda environment, running:

conda install -c conda-forge numpyro
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

numpyro 0.8.0 pyhd8ed1ab_0 conda-forge
jax 0.2.24 pyhd8ed1ab_0 conda-forge
jaxlib 0.1.73+cuda11.cudnn82 pypi_0 pypi

solves the problem. So my earlier installation was most likely broken by pip/conda installs of other packages.
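In the working setup, jaxlib carries a local version label (0.1.73+cuda11.cudnn82), as did the earlier workaround (0.1.72+cuda111), whereas the plain conda-forge jaxlib 0.1.73 that did not see the GPU presumably had none. The sketch below is a heuristic check built on that observation, using only the standard library (importlib.metadata, Python 3.8+). The helper names are hypothetical, and a missing "+cuda" label is only a hint, not proof, that the build is CPU-only.

```python
from importlib import metadata  # stdlib since Python 3.8


def is_cuda_build(version):
    """Heuristic (hypothetical helper): True if the version has a "+cuda..." local label.

    Examples of labeled GPU wheels: "0.1.72+cuda111", "0.1.73+cuda11.cudnn82".
    """
    _, _, local = version.partition("+")
    return local.startswith("cuda")


def installed_jaxlib_version():
    """Return the installed jaxlib version string, or None if not installed."""
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jaxlib_version()
    if v is None:
        print("jaxlib is not installed")
    elif is_cuda_build(v):
        print(v, "-> CUDA-labeled build")
    else:
        print(v, "-> no CUDA label (possibly CPU-only)")
```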