Multiple chains run much slower than a single chain


I am new to numpyro, so please bear with me. A few years ago I wrote a spatiotemporal model in Stan for analysing climate extremes. Recently, I decided to translate this model to numpyro to see whether it would run faster (using NUTS). With “num_chains=1”, the model indeed runs 3x faster in numpyro (on CPU), and the results are identical to those from Stan, which is great. However, when I set “num_chains” to more than 1, the chains run much slower, almost an order of magnitude slower (1 hour for a single chain vs 10 hours with multiple chains). I was hoping you could help me understand why this happens and how to fix it. Of course, I could just run several chains separately and merge them afterwards, but this is not ideal. I have a server available with CPU(s)=56, Core(s)=14, RAM=376GB. I am calling “numpyro.set_host_device_count(56)” at the beginning. I would be really grateful if you could put me out of my misery and help me figure out what’s going on.

By the way, in Stan, multiple chains run just as fast as a single chain.

Many thanks,

Please refer to other threads in this forum, like this one

Thanks very much for the link. However, this is not a case of some chains running slower than others; all of the chains consistently run almost 10x slower. Yes, in Stan some chains will run slightly faster or slower, depending on how they adapt. But the issue here seems to be that computational performance degrades dramatically when the chains run in parallel. I was wondering whether there is something obvious I am missing or something I need to be aware of, hence my message.

Edit: In numpyro, when I run 4 chains in parallel, ‘top’ shows a single PID with a CPU usage of 400%. In Stan, I see four PIDs, each with a CPU usage of 100%. Does that mean the four threads are running on a single core?
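A possible explanation for the single PID: as far as I understand, `numpyro.set_host_device_count(n)` only sets the XLA flag `--xla_force_host_platform_device_count=n`, which splits the host into `n` virtual CPU devices inside one process, and `chain_method="parallel"` then maps the chains over those devices with `jax.pmap`. The sketch below (assuming a recent JAX; the `jnp.sum(x ** 2)` function is just a placeholder) sets the flag directly and runs one pmap'd computation per device. Everything happens under one PID, with the work spread over that process's threads, so 400% in `top` means roughly four cores' worth of CPU time rather than one core.

```python
import os

# Same effect as numpyro.set_host_device_count(4): ask XLA to expose four
# virtual CPU devices. Must be set before JAX initialises its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
print("PID:", os.getpid())
print("virtual CPU devices:", cpu_devices)

# pmap runs one copy of the computation per device (one row of the input
# each), all inside this single process.
per_device = jax.pmap(lambda x: jnp.sum(x ** 2), devices=cpu_devices)(
    jnp.arange(8.0).reshape(4, 2)
)
print(per_device)  # values 1, 13, 41, 85 (one per device)
```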

XLA is not optimized for parallelizing over CPU cores. If you run 4 chains, you can set the host device count to 4. For a large number of chains, you can use chain_method="vectorized" without setting the host device count, but I guess this approach is only fast on GPU. You might want to explore mpi4jax to launch single-chain MCMC runs on multiple processes.