Understanding chain_method

First of all, thank you for this great library!

I am trying to run a model that has a computationally intensive likelihood function. That is, in the definition of my numpyro model, I have a function (written in JAX) that takes a few seconds to evaluate once. I am trying to make it faster using either the JAX built-in vectorization or some form of parallel computing.

To start, I am just trying to run 1 chain on a single node that has 15 CPUs. (Here I am using SLURM with --cpus-per-task=15.)

In my python file, I have

import jax
import numpyro
numpyro.set_host_device_count(15)
print('jax.local_device_count', jax.local_device_count(backend=None))  # prints 15
print('jax.device_count', jax.device_count(backend=None))  # prints 15
…
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=1,
    chain_method="vectorized",
    progress_bar=False,
)

However, I am only observing ~140% CPU usage instead of 1500%. I would have thought, according to the set_host_device_count documentation, that “by default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use.” Am I doing something wrong here?

Moreover, when I try

mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=15,
    chain_method="parallel",
    progress_bar=False,
)

I do get 1300% CPU usage as expected. But since my model is slow, I was hoping to use the entire node to run 1 chain, and run embarrassingly parallel chains over different nodes (if that’s possible). Thank you so much for the help!

Best,
Alan

If num_chains is 1, sampling uses 1 device, so you might as well set the host device count to 1. The same applies if you want to draw multiple chains with the vectorized method: it’s better to set the host device count to 1.

Only set a higher host device count if you want to use the parallel method. As you observed, you get ~1300% CPU usage in that case.

If you want to use all CPU cores with 1 host device, it’s better to reach out to the JAX devs to see how to achieve that behavior.

Thank you so much! For my use case (and I suppose for computationally intensive likelihoods in general), it might be useful to leverage all CPUs with 1 host device.

Just to make sure I understand how numpyro works: is it right to say that num_chains tells the sampler to jax.pmap over local host devices? If that is true, does it make sense (or is it good practice) to set

# multiple host devices
numpyro.set_host_device_count(15)
# and num_chains=1
mcmc = MCMC(
    nuts_kernel,
    num_samples=50,
    num_warmup=50,
    num_chains=1,
    chain_method="parallel",
    progress_bar=False,
)

and then use jax.pmap to parallelize my likelihood function within the model definition? Or will the pmap within the model definition conflict with what the sampler is doing?

num_chains tells the sampler how many MCMC chains you need. I don’t think either num_chains or chain_method is useful for your use case. It is better to reach out to the JAX devs to see how to evaluate your likelihood using all CPU cores. Using pmap for the likelihood will probably work, but I don’t know.
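For reference, this is roughly what "using pmap for the likelihood" could look like on its own, outside of any sampler. It is only a sketch under assumptions: the Gaussian `chunk_loglik`, the data, and the value 0.3 are made up, and the XLA flag is set by hand here to expose 4 host devices. Whether a pmap like this composes well with the NUTS kernel inside a numpyro model is exactly the open question above and is not verified here.

```python
import os

# Assumption: ask XLA for 4 host (CPU) devices. Must be set before JAX
# initializes; setdefault leaves an existing XLA_FLAGS value untouched.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def chunk_loglik(mu, y_chunk):
    # Hypothetical Gaussian log-likelihood of one chunk of data.
    return jnp.sum(-0.5 * (y_chunk - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi))


# Made-up data, split into one equally sized chunk per device.
y = jnp.linspace(-1.0, 1.0, 8 * n_dev)
y_chunks = y.reshape(n_dev, -1)

# mu is broadcast to every device (in_axes=None); the data is split along
# its leading axis (in_axes=0), so each device evaluates one chunk.
parallel_loglik = jax.pmap(chunk_loglik, in_axes=(None, 0))

mu = jnp.asarray(0.3)
total = parallel_loglik(mu, y_chunks).sum()

# Single-device reference value for comparison.
reference = chunk_loglik(mu, y)
```

This only helps if the likelihood decomposes into independent pieces (here, a sum over data points); a likelihood that is one long sequential computation would not split this way.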

If you have multiple GPU/TPU devices, you can distribute computations using sharding, as in the example in the MCMC docs. I’m not sure about CPUs.
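As a rough illustration of the sharding idea (this is not the example from the MCMC docs, just a sketch using the general jax.sharding API), the data can be placed across all visible devices and a jitted function then operates on the sharded array. The function `loglik`, the data, and the mesh axis name "data" are hypothetical, and the XLA flag is an assumption used to get several host devices on a CPU-only machine. Whether this gives a real speedup on CPUs is not something this sketch establishes.

```python
import os

# Assumption: expose 4 host (CPU) devices; on GPU/TPU machines this flag
# is unnecessary and jax.devices() returns the accelerators instead.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all visible devices, with a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# Made-up data whose length is divisible by the number of devices.
y = jnp.linspace(-1.0, 1.0, 8 * devices.size)
y_sharded = jax.device_put(y, data_sharding)


@jax.jit
def loglik(mu, y):
    # Hypothetical likelihood; the compiler decides how to split the work
    # according to the sharding of the input array.
    return jnp.sum(-0.5 * (y - mu) ** 2)


sharded_value = loglik(0.3, y_sharded)
plain_value = loglik(0.3, y)
```

The sharded array could then be passed to the model as its data argument; the function body itself does not change.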