First of all, thank you for this great library!
I am trying to run a model that has a computationally intensive likelihood function. That is, in the definition of my numpyro model, I have a function (written in JAX) that takes a few seconds to evaluate once. I am trying to speed it up using either JAX's built-in vectorization or some form of parallel computing.
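Below is a minimal sketch of the vectorization route. It assumes the likelihood is a sum of independent per-observation terms. `term_loglik` and the toy data are hypothetical stand-ins, not the actual model. `jax.vmap` replaces a Python loop over observations with one batched computation.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-observation log-likelihood term; stands in for the
# expensive JAX likelihood (assumes terms are independent and summed).
def term_loglik(theta, datum):
    return -0.5 * jnp.sum((datum - theta) ** 2)

def loglik_loop(theta, data):
    # Python loop over rows: one small computation per observation.
    return sum(term_loglik(theta, d) for d in data)

@jax.jit
def loglik_vmap(theta, data):
    # vmap maps term_loglik over the leading axis of data (theta is shared).
    return jnp.sum(jax.vmap(term_loglik, in_axes=(None, 0))(theta, data))

theta = jnp.array([0.1, -0.2, 0.3])
data = jax.random.normal(jax.random.PRNGKey(0), (200, 3))
print(loglik_vmap(theta, data))
```

Vectorization only helps if the per-term work can actually be batched. It does not by itself use more than one XLA device.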
To start, I am just trying to run 1 chain on a single node that has 15 CPUs (I am using SLURM with `--cpus-per-task=15`).
In my Python file, I have:
```python
import jax
import numpyro
from numpyro.infer import MCMC

# Only takes effect if called before JAX initializes its CPU backend.
numpyro.set_host_device_count(15)
print('jax.local_device_count ', jax.local_device_count(backend=None))  # prints 15
print('jax.device_count ', jax.device_count(backend=None))  # prints 15
# …
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=1,
    chain_method="vectorized",
    progress_bar=False,
)
```
However, I only observe ~140% CPU usage instead of 1500%. The `set_host_device_count` documentation says that “by default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use.” Based on that, I expected all 15 cores to be used. Am I doing something wrong here?
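The sketch below checks where work actually runs. It assumes a recent JAX release in which `jax.Array.devices()` exists. It sets the XLA flag directly, which is the flag `numpyro.set_host_device_count` sets. `work` is a hypothetical toy function.

A single jitted call runs on one logical device. The other logical devices are used only when work is explicitly mapped onto them, for example with `jax.pmap`.

```python
import os

# Must be set before JAX initializes its CPU backend; this is the same flag
# numpyro.set_host_device_count(15) writes into XLA_FLAGS.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=15"

import jax
import jax.numpy as jnp

print(jax.device_count())  # 15 if the flag took effect

@jax.jit
def work(x):
    # Hypothetical stand-in for an expensive computation.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(1_000_000.0)
y = work(x)
# A plain jitted call is placed on a single (default) device.
print(y.devices())

# Spread the same work over all logical devices: one slice per device.
n = jax.local_device_count()
usable = x[: n * (x.size // n)]
per_device = jax.pmap(work)(usable.reshape(n, -1))
print(per_device.shape)  # one partial result per device
```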
By contrast, when I try:
```python
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=15,
    chain_method="parallel",
    progress_bar=False,
)
```
I do get ~1300% CPU usage, as expected. But since my model is slow, I was hoping to use the entire node to run 1 chain, and to run embarrassingly parallel chains across different nodes (if that's possible). Thank you so much for the help!
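One possible way to spread a single chain's likelihood over the 15 host devices uses `jax.pmap`. This sketch assumes the likelihood is a sum over independent data shards. `chunk_loglik`, `sharded_loglik`, and the toy data are hypothetical, and the number of rows must be divisible by the device count.

```python
import os

# Same flag numpyro.set_host_device_count(15) sets; must precede JAX init.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=15"

import jax
import jax.numpy as jnp

def chunk_loglik(theta, chunk):
    # Hypothetical expensive likelihood for one shard of the data.
    return -0.5 * jnp.sum((chunk - theta) ** 2)

def sharded_loglik(theta, data):
    n = jax.local_device_count()
    # Split rows into n shards (assumes len(data) % n == 0).
    shards = data.reshape(n, -1, data.shape[-1])
    # One shard per device; theta (in_axes=None) is broadcast to every device.
    per_device = jax.pmap(chunk_loglik, in_axes=(None, 0))(theta, shards)
    return jnp.sum(per_device)

n = jax.local_device_count()
theta = jnp.array([0.1, -0.2, 0.3])
data = jax.random.normal(jax.random.PRNGKey(0), (n * 40, 3))
print(sharded_loglik(theta, data))
```

Inside a model, this value could be added with `numpyro.factor("loglik", sharded_loglik(theta, data))`. This sketch does not test whether nesting `pmap` inside the jit-compiled, differentiated NUTS step is actually faster than the single-device version.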
Best,
Alan