I use this idiom to install NumPyro in Colab:
```
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```
It installs fine, and JAX reports that the GPU is available (when I select the GPU runtime):
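```python
import jax

# Report which XLA backend JAX is using ("gpu" when a GPU runtime is selected).
# jax.default_backend() replaces the deprecated jax.lib.xla_bridge.get_backend().platform
print("jax backend {}".format(jax.default_backend()))
print(jax.devices())
```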
Currently, if I run with `num_chains=4`, I get a warning that there is only 1 CPU (even in Colab Pro).
Is there some way to leverage a single GPU to speed up parallel chains?