When I run MCMC in NumPyro, it always warns: “There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially.” This happens even though I am using Colab Pro. The Google team told me that “External colab standard VMs are 2 vCPU, highmem VMs are 4vCPU,” and `!cat /proc/cpuinfo` does show 4 CPUs when the high-RAM runtime is enabled.
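To check the CPU count from Python instead of the shell, a stdlib-only sketch like the one below works on Linux runtimes such as Colab (`count_logical_cpus` is a hypothetical helper name, not part of any library). Keep in mind that the CPU count and the number of JAX devices are different things: in the JAX versions I have used, the CPU backend presents the whole host as a single device by default, and that device count is what NumPyro's warning refers to.

```python
import os


def count_logical_cpus() -> dict:
    """Report logical CPU counts as seen from Python (Linux-oriented sketch)."""
    counts = {"os.cpu_count": os.cpu_count()}
    # CPUs this process is allowed to run on (only available on Linux).
    if hasattr(os, "sched_getaffinity"):
        counts["sched_getaffinity"] = len(os.sched_getaffinity(0))
    # Same source as `!cat /proc/cpuinfo`: one "processor" line per logical CPU.
    if os.path.exists("/proc/cpuinfo"):
        with open("/proc/cpuinfo") as f:
            counts["/proc/cpuinfo"] = sum(1 for line in f if line.startswith("processor"))
    return counts


if __name__ == "__main__":
    print(count_logical_cpus())
```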
Below is the code I tried in order to make JAX detect the extra CPUs:
```python
import os
import jax
import numpyro

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
numpyro.set_host_device_count(2)  # ask for 2 host (CPU) devices
print(jax.lib.xla_bridge.device_count())  # prints 1
print(jax.local_device_count())  # prints 1
```
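For context, `numpyro.set_host_device_count` has, at least in the releases I have read, worked by writing `--xla_force_host_platform_device_count=N` into the `XLA_FLAGS` environment variable, and JAX only reads that variable once, when its CPU backend is first initialized. A minimal stdlib sketch of that mechanism is below; `force_host_device_count` is a hypothetical name, and this is an illustration under that assumption, not NumPyro's actual code.

```python
import os
import re

FLAG = "--xla_force_host_platform_device_count"


def force_host_device_count(n: int) -> str:
    """Set XLA_FLAGS so XLA's CPU backend exposes `n` host devices.

    Only effective if it runs before JAX initializes its backend, i.e. before
    any JAX computation or device query happens in this process.
    """
    if n < 1:
        raise ValueError("device count must be at least 1")
    # Remove any earlier device-count flag so the new value is the only one.
    others = re.sub(rf"{FLAG}=\S+", "", os.environ.get("XLA_FLAGS", "")).split()
    flags = " ".join([f"{FLAG}={n}"] + others)
    os.environ["XLA_FLAGS"] = flags
    return flags


if __name__ == "__main__":
    # Match the device count to the logical CPUs Python can see.
    print(force_host_device_count(os.cpu_count() or 1))
```

In Colab this (or `numpyro.set_host_device_count(4)`) would have to go in the first cell of a freshly restarted runtime, before anything touches JAX, followed by `import jax` and a check of `jax.local_device_count()`. If JAX has already run a computation in the session, its backend is already initialized and changing `XLA_FLAGS` afterwards has no effect until the runtime is restarted.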