GPU on Colab

Hello all,

I am trying out NumPyro, hoping I will be able to use it on the GPU with a (relatively) large dataset. I also hope to be able to use Funsor-based enumeration.

I want to run it on Colab, so I choose a GPU-backed runtime and install the packages as follows:

!pip install numpyro
!pip install funsor

Then I load the packages and check jax’s device:

import jax
import numpyro
import funsor

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

Unfortunately it shows “cpu”. Do you know how to enable the GPU for JAX in Colab?

Many thanks,
Szymon

Hi @Elchorro, you can use numpyro.set_platform to select CPU, GPU, or TPU on Colab. For Cloud TPU, you need some extra configuration, as in the JAX demo.
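
For reference, a minimal sketch of that suggestion (assuming the Colab runtime actually has a GPU attached); call set_platform before running any JAX computation:

import numpyro
numpyro.set_platform("gpu")  # tell NumPyro/JAX to use the GPU backend

import jax
print(jax.devices())  # should now list a GPU device if the backend is available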


I tried this, but it did not work.
I get the error Unknown backend GPU. Available: ['interpreter', 'cpu']

The problem seems to be that pip install numpyro triggers Uninstalling jaxlib-0.1.65+cuda110: the CUDA build of jaxlib (preinstalled in Colab) gets removed and replaced with a CPU-only version (Successfully installed jax-0.2.10 jaxlib-0.1.62 numpyro-0.6.0).
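
As a quick sanity check (a suggestion, not from the original post), you can list the installed versions in a Colab cell; a CUDA-enabled jaxlib carries a +cudaXXX suffix in its version string:

# List installed jax/jaxlib/numpyro versions; the CUDA build of jaxlib
# shows something like 0.1.65+cuda110, while the CPU build has no suffix.
!pip list | grep -iE "jax|numpyro"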

ziatdinovmax posted the solution here:

@murphyk Currently, we pin jax/jaxlib versions in each release. From the next release, we’ll relax that restriction. For now, the simplest way is to add

!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

to the top of your Colab notebook, as in the Bayesian regression tutorial.

Thanks, that is much simpler and faster 🙂

I use this idiom to install NumPyro in Colab:

!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

It runs fine, and JAX says the GPU is available (if I select that runtime):

import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()

Currently, if I run with num_chains=4, I get a warning that there is only 1 CPU device (even in Colab Pro).
Is there some way to leverage a single GPU to speed up parallel chains?

I think you can use MCMC(..., chain_method="vectorized") to take advantage of the GPU. But in my experience, it is only noticeably faster with many chains, e.g. num_chains=100.
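
For illustration, here is a minimal sketch of the vectorized-chains setup (the toy model and data are placeholders, not from the thread):

import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y):
    # toy model: normal mean with a standard normal prior
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", len(y)):
        numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

y_obs = jnp.array([0.3, -0.1, 0.8, 0.2])

# "vectorized" batches all chains into one computation on a single device,
# so one GPU can run them together; "parallel" would need multiple devices.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000,
            num_chains=100, chain_method="vectorized")
mcmc.run(random.PRNGKey(0), y=y_obs)
mcmc.print_summary()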