I am trying to wrap my head around how the total runtime (compile time + execution time) scales as I change num_chains and the number of cores I allocate to numpyro.
num_chains=1, 8 cores:
sample: 100%|█████| 150/150 [06:32<00:00, 2.62s/it, 63 steps of size 1.35e-01. acc. prob=0.62]
num_chains=2, 8 cores:
Running chain 0: 100%|███████████████████████████████████████| 150/150 [13:12<00:00, 5.28s/it]
Running chain 1: 100%|███████████████████████████████████████| 150/150 [13:12<00:00, 5.28s/it]
This makes sense. With the same number of cores, doubling num_chains doubles both the total execution time and the number of seconds per iteration.
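As a quick sanity check of the doubling, here is a minimal sketch that recomputes the seconds per iteration from the elapsed times in the two progress bars above. The helper name to_seconds is my own and is not part of numpyro. The elapsed times are rounded to whole seconds, so the results may differ slightly from the displayed values.

```python
def to_seconds(mmss: str) -> int:
    """Convert an 'MM:SS' elapsed-time string from a progress bar to seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


ITERATIONS = 150  # total iterations shown in each progress bar

# Elapsed times copied from the two finished runs above (both on 8 cores).
elapsed = {1: to_seconds("06:32"), 2: to_seconds("13:12")}

# Average seconds per iteration, recomputed from total elapsed time.
per_iteration = {chains: secs / ITERATIONS for chains, secs in elapsed.items()}

# How much longer the 2-chain run took than the 1-chain run.
ratio = elapsed[2] / elapsed[1]

for chains, secs in elapsed.items():
    print(f"num_chains={chains}: {secs} s total, {per_iteration[chains]:.2f} s/it")
print(f"2 chains / 1 chain: {ratio:.2f}x")
```

The ratio comes out at roughly 2, consistent with the two chains sharing the same 8 cores.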
num_chains=3, 24 cores:
Running chain 0: 100%|███████████████████████████████████████| 150/150 [16:47<00:00, 6.72s/it]
Running chain 1: 100%|███████████████████████████████████████| 150/150 [16:47<00:00, 6.72s/it]
Running chain 2: 100%|███████████████████████████████████████| 150/150 [16:47<00:00, 6.72s/it]
I fail to understand this, though. Intuitively, I would have guessed that this would take the same time as the num_chains=1, 8 cores case (the first case above), since there are still 8 cores per chain. But it takes much longer, both in total time and in the final time per iteration shown.
An intermediate snapshot (taken before the runs had finished) shows:
Running chain 0: 56%|██████████████████████▍ | 84/150 [14:16<04:15, 3.87s/it]
Running chain 1: 70%|███████████████████████████▎ | 105/150 [14:15<02:12, 2.95s/it]
Running chain 2: 70%|███████████████████████████▎ | 105/150 [14:36<02:07, 2.82s/it]
Here, the time per iteration is closer to what I observed in the first case (num_chains=1, 8 cores). This leads me to believe that the execution time (excluding the initial compilation time) scales as expected with the number of cores available per chain. But the compilation time does not make sense to me.
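To put a rough number on this, here is a sketch that estimates how much of the elapsed time in the intermediate snapshot is not explained by the displayed rate. It rests on two assumptions of mine: that the rate shown on an unfinished bar is a recent, smoothed rate rather than the overall average, and that every completed iteration ran at about that rate. If early warmup iterations were slower, part of this "unexplained" time would be sampling rather than compilation. The helper names are hypothetical.

```python
def to_seconds(mmss: str) -> int:
    """Convert an 'MM:SS' elapsed-time string from a progress bar to seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


# chain -> (iterations done, elapsed time, displayed s/it), copied from the
# intermediate snapshot of the num_chains=3, 24 cores run.
snapshot = {
    0: (84, "14:16", 3.87),
    1: (105, "14:15", 2.95),
    2: (105, "14:36", 2.82),
}


def unexplained_seconds(done: int, elapsed: str, shown_rate: float) -> float:
    """Elapsed time left over if every iteration ran at the displayed rate."""
    return to_seconds(elapsed) - done * shown_rate


overhead = {chain: unexplained_seconds(*row) for chain, row in snapshot.items()}

for chain, secs in overhead.items():
    print(f"chain {chain}: ~{secs / 60:.1f} min not explained by the displayed rate")
```

Under these assumptions, each chain has roughly 9 to 10 minutes of elapsed time that the displayed per-iteration rate does not account for.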
Could someone help me make sense of the above observations? This would help me decide how many cores to allocate when increasing num_chains so that the runtime does not increase (as I would expect with parallelization).