CPU utilization

I am running SVI with a NumPyro model, but it is not using all of my CPUs when I call run(). Am I doing something incorrectly? For reference, I have 10 CPUs, but the average utilization on my container is only about 2 CPUs.

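        import os
        import numpyro
        numpyro.set_host_device_count(os.cpu_count())  # only takes effect before JAX initializes
        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs so JAX falls back to the CPU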

I don’t know what’s going on under the hood, but the answers in this GitHub issue may be useful: "JAX running in CPU only mode only uses a single core" (google/jax issue #5022).