When I run MCMC in numpyro, I always get the warning “There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially.” This happens even though I am using Colab Pro, and the Google team told me “External colab standard VMs are 2 vCPU, highmem VMs are 4vCPU.” Running
!cat /proc/cpuinfo also shows 4 CPUs (in high-RAM mode).
Below is the code I tried to make it detect the CPUs:
    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
    numpyro.set_host_device_count(2)
    print(jax.lib.xla_bridge.device_count())  # prints 1
    print(jax.local_device_count())  # prints 1
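One thing that may be worth checking: as far as I understand it, the device-count flag only takes effect if it is set before JAX initializes its backend, and numpyro's documentation notes that set_host_device_count only works at the beginning of the program. The sketch below uses only the standard library. The helper name force_host_device_count is hypothetical, and looking for "jax" in sys.modules is only a rough heuristic for "JAX may already be initialized in this session".

```python
import os
import re
import sys


def force_host_device_count(n: int) -> str:
    """Hypothetical helper: put the device-count flag into XLA_FLAGS,
    replacing any earlier value and keeping all other flags."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop a previously set device-count flag so it is not duplicated.
    flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", flags).split()
    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + flags
    )
    return os.environ["XLA_FLAGS"]


# Heuristic (assumption): if jax is already imported in this process,
# the flag may be set too late to change the device count.
jax_already_imported = "jax" in sys.modules

xla_flags = force_host_device_count(2)
print("jax imported before the flag was set:", jax_already_imported)
print("XLA_FLAGS =", xla_flags)
```

If this reports that jax was already imported, restarting the Colab runtime and running the flag-setting cell first, before any import of jax or numpyro, might be enough for jax.local_device_count() to report 2.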