Stop Numpyro Operations Using All Cores

Hi all,

I was hoping to get some help with preventing NumPyro and JAX from running operations across all cores, so that I can easily run many experiments on a server machine.

For context, I am using a 48-core server machine and would like to run 12 different models (each using 4 cores) at the same time, all using NumPyro. The models are different, so I’m not looking to save time on model compilation. Instead, I want to prevent JAX operations from using all the cores.

Each model has a model_run.py script, which loads the model and data and runs the model. Roughly speaking, my workflow is to launch model_run.py for each model in the terminal.

In each file, I use:

import os 

os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false " "intra_op_parallelism_threads=1"
)

import numpyro
numpyro.set_host_device_count(num_chains)

# run model

I hoped that this would ensure that each script uses only 4 cores, so I can run 12 of them at the same time. But it seems that each process distributes its workload across all cores of the server machine.

Does anybody have any suggestions?

Thanks!

Hi @mrinank_sharma, could you try the following flags?

os.environ["XLA_FLAGS"] = (
    "--xla_force_host_platform_device_count=4 "
    "--xla_cpu_multi_thread_eigen=false "
    "--intra_op_parallelism_threads=1"
)

Hey,

I just tried this, but it doesn’t seem to make a difference. I ran 6 experiments, and the machine showed about 60% CPU utilisation on all of the cores.

I guess I wouldn’t mind if doing this does not slow down the models, but I can’t tell, as the progress bar isn’t supported for parallel chains.

FYI, progress bars for parallel chains will be supported in the next release. I will take a look at this issue. In the meantime, could you try this solution (which is recommended by a JAX dev)? Maybe the flag inter_op_parallelism_threads=1 is also relevant. Also, it may be TF_XLA_FLAGS rather than XLA_FLAGS:

os.environ["TF_XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false "
    "--intra_op_parallelism_threads=1"
)

Thanks! I’m looking forward to having progress bars for parallel chains :)

In case anybody else has a similar use case, I ended up setting processor affinities using taskset -c. This seems to work and is suitable for my use case. Thanks for the suggestion.
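
For anyone who wants to script the taskset approach, here is a minimal sketch. It assumes Linux with taskset available, 4 cores per job, and contiguous core numbering starting at 0. The helper names (core_range, taskset_command, launch_all) and the script paths are hypothetical and are not part of NumPyro or JAX.

```python
import subprocess

CORES_PER_JOB = 4  # assumption: 4 cores per model, as in the thread


def core_range(job_index, cores_per_job=CORES_PER_JOB):
    """Return a taskset-style CPU list, e.g. '4-7' for job 1 with 4 cores."""
    start = job_index * cores_per_job
    return f"{start}-{start + cores_per_job - 1}"


def taskset_command(job_index, script, cores_per_job=CORES_PER_JOB):
    """Build the command that runs `script` pinned to its own block of cores."""
    return ["taskset", "-c", core_range(job_index, cores_per_job), "python", script]


def launch_all(scripts):
    """Start every script under taskset and wait for all of them to finish."""
    procs = [subprocess.Popen(taskset_command(i, s)) for i, s in enumerate(scripts)]
    return [p.wait() for p in procs]


# Example (hypothetical paths): 12 jobs on a 48-core machine use cores 0-47.
# launch_all([f"model_{i}/model_run.py" for i in range(12)])
```

Threads inherit the affinity mask of the process that creates them, so any threads JAX starts inside a pinned script should stay on the cores assigned to that job.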
