GPU on Colab

I am trying out NumPyro, hoping to use it on a GPU with a (relatively) large dataset. I also hope to use Funsor-based enumeration.

I want to run it on Colab, so I chose a GPU-backed runtime and installed the packages as follows:

!pip install numpyro
!pip install funsor

Then I import the packages and check JAX's device:

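import jax
import numpyro
import funsor

from jax.lib import xla_bridge

# Print the platform of the backend JAX is using ("cpu", "gpu", or "tpu").
print(xla_bridge.get_backend().platform)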

Unfortunately, it shows “cpu”. Do you know how to enable the GPU for JAX in Colab?
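A minimal sketch of the same check using JAX's top-level helpers `jax.default_backend()` and `jax.devices()`, which newer JAX releases prefer over `jax.lib.xla_bridge`. On a working GPU runtime the backend should read "gpu" and the device list should include the CUDA device; on a CPU-only setup both report the CPU.

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX uses by default: "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("default backend:", backend)

# All devices JAX can see on that backend (e.g. one GPU on Colab).
devices = jax.devices()
print("devices:", devices)

# Run a tiny computation; it executes on the default backend.
total = jnp.ones((1000,)).sum()
print("sum:", float(total))
```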

Hi @Elchorro, you can use numpyro.set_platform to select CPU, GPU, or TPU on Colab. For Cloud TPU, you need some extra configuration, as in the JAX demo.

