How to use cores across nodes?

Is there a special key or flag that we have to use in numpyro to enable the use of cores across nodes? Let's say we want to run an 80-core job. Our cluster has 40 cores per node, so we allocate 2 nodes with 40 cores each using a SLURM script. When we then log in to the nodes and run htop -u $USER, ALL the cores on one of the nodes are in use and NONE on the other node.

From going through various JAX discussions and issues, I guess the answer is to figure out how to run JAX code on multiple nodes? It doesn't seem to have anything to do with numpyro.

I have found a way to parallelize JAX code across nodes using https://mpi4jax.readthedocs.io. But is there any reason to suspect that this would interfere with the way numpyro's internals are written (for example, pmap)?

I don't know how pmap works with other frameworks, but you can try. By the way, the multi-chain feature of MCMC is not necessary here. You can run MCMC with a single chain and then parallelize it however you want. For example, the following code should give similar performance to using num_chains=4:

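```python
import numpyro
numpyro.set_host_device_count(4)  # CPU: expose 4 devices to pmap; call before JAX starts
import jax
from numpyro.infer import MCMC

def get_samples(rng_key):
    mcmc = MCMC(..., num_chains=1)  # "..." = your kernel and sampler settings
    mcmc.run(rng_key)  # pass the model's arguments after rng_key
    return mcmc.get_samples()

chain_samples = jax.pmap(get_samples)(jax.random.split(jax.random.PRNGKey(0), 4))
```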

Thanks a lot. This confirms what we have been guessing till now. I assume though that this is equivalent to running 4 different instances of the single chain inversion on a single node with multiple cores. It would, perhaps, still not detect cores in other nodes. Sigh!