First of all, thank you for this great library!
I am trying to run a model that has a computationally intensive likelihood function. That is, in the definition of my numpyro model, I have a function (written in JAX) that takes a few seconds to evaluate once. I am trying to make it faster using either the JAX built-in vectorization or some form of parallel computing.
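For concreteness, here is a minimal sketch of the kind of parallelism I have in mind. It is not my actual model: it assumes the log-likelihood is a sum over independent data points, the Gaussian term is only a placeholder for the expensive function, and `shard_loglik` and `total_loglik` are hypothetical names. To keep the snippet JAX-only, it sets the XLA flag directly, which as far as I understand is what numpyro.set_host_device_count does internally.

```python
import os

# Assumption: this must run before JAX initializes its backends. As far as I
# understand, numpyro.set_host_device_count sets this same XLA flag.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def shard_loglik(theta, x):
    # Placeholder for the expensive likelihood, evaluated on one shard of data.
    return jnp.sum(-0.5 * (x - theta) ** 2)


# Broadcast theta to every device and split the data along its leading axis.
p_loglik = jax.pmap(shard_loglik, in_axes=(None, 0))


def total_loglik(theta, data):
    # One row of data per device; the data length must be divisible by n_dev.
    shards = data.reshape(n_dev, -1)
    return jnp.sum(p_loglik(theta, shards))


data = jnp.arange(3.0 * n_dev)
total = total_loglik(1.0, data)        # evaluated across all host devices
reference = shard_loglik(1.0, data)    # same quantity on a single device
```

Here `total` and `reference` should agree up to floating-point error. Whether something like this can be used inside a NumPyro model under NUTS is part of what I am asking.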
To start, I am just trying to run 1 chain on a single node that has 15 CPUs. (Here I am using SLURM with
In my python file, I have
import numpyro

numpyro.set_host_device_count(15)
print('jax.local_device_count ', jax.local_device_count(backend=None))  # prints 15
print('jax.device_count ', jax.device_count(backend=None))  # prints 15
…
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=1,
    chain_method="vectorized",
    progress_bar=False,
)
However, I am only observing ~140% CPU usage instead of 1500%. I expected all 15 devices to be used, because the set_host_device_count documentation says that “by default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use.” Am I doing something wrong here?
Moreover, when I try
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=15,
    chain_method="parallel",
    progress_bar=False,
)
I do get 1300% CPU usage, as expected. But since my model is slow, I was hoping to use the entire node to run 1 chain, and to run embarrassingly parallel chains across different nodes (if that’s possible). Thank you so much for the help!