I’m trying to run a numpyro model on a GCE GPU instance running Ubuntu, but I cannot get numpyro to use the GPU. If I install numpyro[cuda] from pip, it tries to downgrade jax and jaxlib to very old versions that don’t run my model at all. I’ve also tried installing the latest numpyro and jax from conda-forge with conda, but then the model fails with:
AttributeError: module 'jaxlib.mlir._mlir_libs._mlir.ir' has no attribute 'DialectRegistry'
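An AttributeError raised from inside jaxlib like this one often points to a jax/jaxlib pair that doesn't match (for example, when the two come from different channels or builds), though that is an assumption rather than something confirmed here. A minimal sketch for checking what the active interpreter actually sees, using only the standard library's importlib.metadata; the helper name installed_versions is hypothetical:

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Run this with the same interpreter that runs the model:
    # pip and conda environments can resolve to different packages.
    for pkg, ver in installed_versions(["jax", "jaxlib", "numpyro"]).items():
        print(f"{pkg}: {ver if ver is not None else 'not installed'}")
```

If jax and jaxlib report noticeably different versions, that mismatch is a plausible place to start.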
The only setup in which my model runs at all is a standard pip install -U jax numpyro, but then JAX no longer uses the GPU:
Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
and I can confirm:
>>> jax.lib.xla_bridge.get_backend().platform
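Newer JAX releases also expose jax.default_backend() and jax.devices(), which give the same information without going through jax.lib.xla_bridge. A small sketch, assuming only those two public calls; the helper name jax_backend_report is hypothetical, and it returns None if jax itself fails to import (for instance with an AttributeError like the one above):

```python
import importlib


def jax_backend_report():
    """Return JAX's default backend and visible devices, or None if jax fails to import."""
    try:
        jax = importlib.import_module("jax")
    except (ImportError, AttributeError) as exc:
        # A broken jax/jaxlib combination can fail at import time.
        print(f"jax could not be imported: {exc!r}")
        return None
    return {
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    report = jax_backend_report()
    print(report)
    if report is not None and report["backend"] != "gpu":
        print("JAX appears to be falling back to a non-GPU backend.")
```

A "cpu" backend here matches the "Available platform names are: Interpreter Host" message, i.e. the installed jaxlib most likely has no CUDA support compiled in.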
This is a GCE VM with an NVIDIA A100 GPU, and I have made sure that CUDA 11 is installed, but that does not seem to make a difference. Any ideas on how to get this working?