Can't get NumPyro to use the GPU on Linux

I’m trying to run a NumPyro model on a GCE GPU instance running Ubuntu, but I can’t get NumPyro to use the GPU. If I install numpyro[cuda] with pip, it tries to downgrade jax and jaxlib to very old versions that can’t run my model at all. I’ve also tried installing the latest numpyro and jax from conda-forge with conda, but then the model fails with:

AttributeError: module '' has no attribute 'DialectRegistry'

The only way my model runs at all is after a standard pip install -U jax numpyro, but then it no longer uses the GPU:

Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host

and I can confirm that JAX is not using the GPU:

>>> jax.lib.xla_bridge.get_backend().platform

This is a GCE VM with an NVIDIA A100 GPU, and I have made sure CUDA 11 is installed, but that doesn’t seem to make a difference. Any ideas on how to get this working?
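As a quick diagnostic, the backend check above can also be done through JAX's public functions jax.default_backend() and jax.devices(), which report the backend JAX picked and every device it can see. The helper name jax_backend_report below is hypothetical, just a minimal sketch; it also reports when jax is not importable at all, which helps tell a broken install apart from a CPU-only jaxlib.

```python
import importlib.util


def jax_backend_report():
    """Return (backend, devices) as seen by JAX, or (None, []) if jax is not installed."""
    if importlib.util.find_spec("jax") is None:
        # jax is not importable in this environment at all.
        return None, []
    import jax

    # default_backend() names the platform JAX will run on ("cpu", "gpu", ...);
    # devices() lists every device that backend exposes.
    return jax.default_backend(), [str(d) for d in jax.devices()]


backend, devices = jax_backend_report()
print("default backend:", backend)
print("devices:", devices)
if backend == "cpu":
    # Falling back to CPU usually means the installed jaxlib has no CUDA support.
    print("JAX is running on the CPU only")
```

If this prints a CPU backend on a GPU machine, the installed jaxlib most likely lacks CUDA support, which matches the "Could not find registered platform with name: \"cuda\"" message.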

I got this working by following the advice here: specifically, running conda install cuda -c nvidia and installing cudatoolkit-dev from conda-forge.
