I was hoping to get some help with preventing numpyro and jax from running operations across all cores, so that I can easily run many experiments on a server machine.
For context, I am using a 48-core server machine and would like to run 12 different models at the same time (each using 4 cores), all using numpyro. The models are different, so I'm not looking to save time on model compilation. Instead, I want to prevent jax operations from using all the cores.
Each model has a model_run.py script, which loads the model and data and runs the model. Roughly speaking, my workflow is to run model_run.py for each model in the terminal.
In each file, I use:
    import os

    os.environ["XLA_FLAGS"] = (
        "--xla_cpu_multi_thread_eigen=false "
        "intra_op_parallelism_threads=1"
    )

    import numpyro

    numpyro.set_host_device_count(num_chains)

    # run model
I hoped that this would ensure that each script uses only 4 cores, so that I can run 12 of them at the same time. But it seems that each process distributes its workload across all cores of the server machine.
Does anybody have any suggestions?
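
One option that may be worth trying, independent of the XLA flags, is to pin each process to its own block of cores at the operating-system level. The following is a minimal sketch, not a confirmed fix. It assumes a Linux machine (os.sched_setaffinity is not available on every platform) and that cores are numbered 0 to 47. The helper names core_block and pin_to_cores, and the idea of passing a job index on the command line, are hypothetical and only for illustration.

```python
import os


def core_block(job_index, cores_per_job=4):
    """Return the set of core ids reserved for the job with this index.

    Job 0 gets cores {0, 1, 2, 3}, job 1 gets {4, 5, 6, 7}, and so on.
    """
    start = job_index * cores_per_job
    return set(range(start, start + cores_per_job))


def pin_to_cores(cores):
    """Restrict the calling process to the given cores (Linux only).

    pid 0 means "the caller"; threads started afterwards inherit the mask,
    so this should run before jax or numpyro create their thread pools.
    """
    os.sched_setaffinity(0, cores)
    return os.sched_getaffinity(0)


# Hypothetical usage at the very top of model_run.py,
# before importing jax or numpyro:
#
#     import sys
#     pin_to_cores(core_block(int(sys.argv[1])))  # e.g. python model_run.py 3
#     import numpyro
#     numpyro.set_host_device_count(4)
```

With this layout, 12 jobs with indices 0 to 11 would cover cores 0 to 47 without overlap. A shell-level equivalent, assuming the taskset utility is installed, would be to launch each script as something like taskset -c 0-3 python model_run.py. Either way, jax may still create more threads than cores in the block, but the scheduler should keep them on the assigned cores.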