How does runtime scale with `num_chains` and the number of cores?

I am trying to wrap my head around how the total runtime (compile time + execution time) scales as I change `num_chains` and the number of cores I allocate to numpyro.

num_chains=1, 8 cores:

sample: 100%|█████| 150/150 [06:32<00:00,  2.62s/it, 63 steps of size 1.35e-01. acc. prob=0.62]

num_chains=2, 8 cores:

Running chain 0: 100%|███████████████████████████████████████| 150/150 [13:12<00:00,  5.28s/it]
Running chain 1: 100%|███████████████████████████████████████| 150/150 [13:12<00:00,  5.28s/it]

This makes sense: with the same number of cores, doubling num_chains roughly doubles both the total execution time and the number of seconds per iteration.

num_chains=3, 24 cores:

Running chain 0: 100%|███████████████████████████████████████| 150/150 [16:47<00:00,  6.72s/it]
Running chain 1: 100%|███████████████████████████████████████| 150/150 [16:47<00:00,  6.72s/it]
Running chain 2: 100%|███████████████████████████████████████| 150/150 [16:47<00:00,  6.72s/it]

I fail to understand this, though. Intuitively, I would have guessed that this would take about the same time as the num_chains=1, 8-core case (the first case above). But it takes much longer, both in total time and in the final reported time per iteration.

An intermediate snapshot (while the runs are still in progress) shows:

Running chain 0:  56%|██████████████████████▍                 | 84/150 [14:16<04:15,  3.87s/it]
Running chain 1:  70%|███████████████████████████▎           | 105/150 [14:15<02:12,  2.95s/it]
Running chain 2:  70%|███████████████████████████▎           | 105/150 [14:36<02:07,  2.82s/it]

Here, the time per iteration is closer to what I observed in the first case (num_chains=1 and 8 cores). This leads me to believe that the execution time (excluding the initial compilation time) scales as I would expect given the number of cores vs. num_chains, but the compilation time does not make sense to me.

Could someone help me understand these observations? That would help me decide how many cores I should allocate so that I can increase num_chains without increasing the runtime (as I would expect with parallelization).

I think you should reach out to the JAX devs for a better explanation. We use pmap to perform parallel chain sampling. From my understanding, XLA compile time grows non-linearly with the number of operators, which might explain your observation.
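
As a rough illustration of what pmap does (a toy sketch, not NumPyro's actual internals): pmap runs one copy of a function per visible device, and numpyro.set_host_device_count is the documented way to expose multiple CPU cores as devices.

import numpyro
numpyro.set_host_device_count(8)  # must run before JAX initializes its devices

import jax

def one_chain(key):
    # trivial stand-in for the work a single chain does
    return jax.random.normal(key, (1000,)).mean()

# one copy of one_chain runs per visible device, in parallel
keys = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())
results = jax.pmap(one_chain)(keys)

That only shows the mechanism; why the compile time itself grows the way it does is something the JAX folks can answer better.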

Regardless, the compile time of your code is quite slow. The JAX team has suggested reporting such slow-to-compile programs (I guess the threshold here is about 1 minute) so they can improve their system. You can take the numpyro code out of the picture by converting your model from

def model():
    x = numpyro.sample("x", dist.Normal(...))
    ...

into

import jax
from jax import random

def jax_fn(key):
    key, subkey = random.split(key)
    x = random.normal(subkey)
    ...

then try

jax.jit(jax_fn)(random.PRNGKey(0))

to see how slow it is, then report to them.
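
To actually separate compile time from execution time, you could also time the first call against a second call, since the second call reuses the cached compilation (a sketch, with a trivial stand-in for jax_fn):

import time
import jax
from jax import random

def jax_fn(key):
    # trivial stand-in; substitute the real converted model here
    key, subkey = random.split(key)
    return random.normal(subkey)

jitted = jax.jit(jax_fn)

t0 = time.perf_counter()
jitted(random.PRNGKey(0)).block_until_ready()  # first call: compile + execute
t1 = time.perf_counter()
jitted(random.PRNGKey(1)).block_until_ready()  # second call: execute only
t2 = time.perf_counter()

print(f"compile + run: {t1 - t0:.1f}s, run only: {t2 - t1:.1f}s")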


So I followed up with a JAX maintainer in this thread, and it seems that the huge compile time is due to the fact that we use only NumPy or plain Python ops, instead of JAX ops, on the arguments that need to remain static (if you remember, this is what we had arrived at in a previous thread here). This causes all the NumPy/Python loops to be unrolled, and as the JAX maintainer warned: “those loops will be flattened and compilation tends to scale with about the square of the number of flattened operations.” This is killing us.
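
For anyone else reading, here is a toy example (not our model, just an illustration of the quoted point) of how an unrolled Python loop blows up the traced program, and how a structured primitive like lax.fori_loop avoids it:

import jax
import jax.numpy as jnp

def unrolled(x, n=1000):
    # a plain Python loop is flattened at trace time into n copies of the body,
    # so the compiled program (and the compile time) grows with n
    for _ in range(n):
        x = x + 1.0
    return x

def structured(x, n=1000):
    # lax.fori_loop keeps a single loop in the compiled program regardless of n
    return jax.lax.fori_loop(0, n, lambda i, v: v + 1.0, x)

jax.jit(unrolled)(jnp.zeros(()))    # compile time grows with n
jax.jit(structured)(jnp.zeros(()))  # compile time stays roughly flat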

The only way around this seems to be to somehow pass static arguments to the model. I think this is possible by using jit_model_args=True? But according to the documentation page, that then prevents us from using num_chains > 1, and hence no parallel chains. Correct me if I am wrong.
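
For reference, this is what static arguments look like at the plain-JAX level (just a sketch of jax.jit's static_argnums; whether and how that maps onto MCMC's jit_model_args and parallel chains is exactly what I am unsure about):

import jax
import jax.numpy as jnp
from functools import partial

# n is declared static, so jit specializes on its concrete value
# (and recompiles for a different n) instead of tracing it as a runtime value
@partial(jax.jit, static_argnums=(1,))
def repeat_add(x, n):
    return jax.lax.fori_loop(0, n, lambda i, v: v + 1.0, x)

repeat_add(jnp.zeros(()), 10)  # compiles for n=10
repeat_add(jnp.zeros(()), 10)  # reuses the cached compilation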

P.S.: We will try the above suggestion soon and report back here.