GPU on Colab

I use this idiom to install NumPyro in Colab:

!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

It runs fine, and JAX reports that the GPU is available (if I select the GPU runtime):

import jax

# Name of the default backend: "cpu", "gpu", or "tpu"
print("jax backend {}".format(jax.default_backend()))
# List the devices JAX can see
jax.devices()

Currently, if I run with num_chains=4, I get a warning that there is only 1 CPU (even on Colab Pro).
Is there some way to use a single GPU to speed up parallel chains?