I am trying out NumPyro, hoping I will be able to use it on a GPU with a (relatively) large dataset. I also hope to be able to use Funsor-based enumeration.
I want to run it on Colab, so I chose a GPU-backed runtime and installed the packages as follows:
!pip install numpyro
!pip install funsor
Then I import the packages and check which backend JAX is using:
import jax
import numpyro
import funsor
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
Unfortunately, it prints “cpu”. Do you know how to enable the GPU for JAX in Colab?
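For reference, the same check can also be done through JAX's public API instead of the internal `jax.lib.xla_bridge` module. This is a minimal sketch that assumes a reasonably recent JAX release providing `jax.default_backend()` and `jax.devices()`. It prints the default platform name and lists every device JAX can see, which shows whether a GPU is visible at all:

```python
import jax

# Platform JAX uses by default, e.g. "cpu" or "gpu".
print(jax.default_backend())

# Devices of the default backend; on a working GPU runtime this should
# list GPU devices rather than only CPU devices.
print(jax.devices())
```

If only CPU devices appear here, JAX itself is not detecting the GPU, so the problem is not specific to NumPyro or Funsor.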