I am new to numpyro, so please bear with me. A few years ago I wrote a spatiotemporal model in Stan for analysing climate extremes. Recently, I decided to translate that model to numpyro to see if it would run faster (using NUTS). When I set “num_chains=1”, the model does indeed run 3x faster in numpyro (on CPU) and the results are identical to those from Stan, which is great. However, when I set “num_chains” to a value greater than 1, the chains run much slower, by almost an order of magnitude (1 hour for a single chain vs 10 hours when using multiple chains). I was hoping you could help me understand why this might be happening and how to fix it. Of course, I could just run several chains separately and merge them afterwards, but this is not ideal. I have a server available with CPU(s)=56, Core(s)=14, RAM=376GB. I am calling “numpyro.set_host_device_count(56)” at the beginning of the script. I would be really grateful if you could put me out of my misery and help me figure out what’s going on.
By the way, in Stan, when I use multiple chains they run just as fast as when using a single chain.