How to leverage multiple CPUs in Colab

When I run MCMC in NumPyro, it always says "There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially." This happens even though I am using Colab Pro, and the Google team told me "External colab standard VMs are 2 vCPU, highmem VMs are 4vCPU." Running !cat /proc/cpuinfo also shows 4 CPUs (in high-RAM mode).
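For reference, the CPU count can also be read from Python instead of parsing /proc/cpuinfo. This is a small sketch using only the standard library; os.sched_getaffinity is Linux-only (Colab VMs run Linux), so it is guarded.

```python
import os

# Logical CPUs the OS reports (the same processors !cat /proc/cpuinfo lists)
print("os.cpu_count():", os.cpu_count())

# CPUs this process is allowed to run on (Linux only; can be smaller)
if hasattr(os, "sched_getaffinity"):
    print("usable CPUs:", len(os.sched_getaffinity(0)))
```

Note that this only tells you what the VM has; JAX will still expose a single CPU device unless it is told otherwise before it initializes.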

Below is the code I tried to make it detect the CPUs:

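import os
import jax
import numpyro

# These lines ran after JAX had already been used in the notebook, so its
# backend was already initialized and both settings were ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
numpyro.set_host_device_count(2)
print(jax.device_count())        # prints 1
print(jax.local_device_count())  # prints 1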

It turns out that if you put this magic incantation at the very start of your Colab notebook, before JAX is imported, it works:

import os
# Must run before importing jax (or numpyro)
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
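A slightly more general version of the same fix, as a sketch: instead of hardcoding 4, it sizes the device count from os.cpu_count() (an assumption on my part, not something from the thread), imports JAX only afterwards, and then checks how many CPU devices JAX sees. If JAX has already been used in the session, restart the runtime before running it.

```python
import os

# Must run before JAX initializes its backend (top of the notebook).
# Note: this replaces any XLA_FLAGS value that was already set.
n_cpus = os.cpu_count() or 1
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n_cpus}"

import jax  # imported only after XLA_FLAGS is set

# Ask for CPU devices explicitly, in case a GPU backend is the default
cpu_devices = jax.devices("cpu")
print(len(cpu_devices), "CPU devices:", cpu_devices)
```

On a high-RAM Colab VM this should report 4 CPU devices; on a standard VM, 2.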

Could you try again with only import numpyro; numpyro.set_host_device_count(4)? It works for me on non-Pro Colab.

Edit: yeah, you need to run it at the start of your program, before JAX runs any computation.