Running chains in parallel slows down per-chain time

Hi all,

I’ve been experimenting with the parallel option for running multiple chains at once. Running several chains at the same time seems to slow each chain down, which I wouldn’t have expected.

For example:

import numpyro
from numpyro.infer import MCMC, NUTS, init_to_median

num_chains = 4

# Expose one host (CPU) device per chain; this needs to run at the start of the program.
numpyro.set_host_device_count(num_chains)

...

nuts_kernel = NUTS(
    some_model,
    init_strategy=init_to_median,
    target_accept_prob=target_accept,
    max_tree_depth=max_tree_depth,
)

mcmc = MCMC(
    nuts_kernel,
    num_samples=num_samples,
    num_warmup=num_warmup,
    num_chains=num_chains,
    chain_method='parallel',
)

mcmc.run(rng_key)

has very different runtimes depending on the number of chains I request. For a model that I’m running locally, with num_chains=1 it takes 2 hours for warmup and 5 hours to sample. With num_chains=4 it takes 5 hours for warmup and 12 hours to sample. That is about 2.5x slower for warmup and 2.4x slower for sampling!

I would have expected 4 chains to run in parallel without much of a cost in speed.

Is this behaviour expected?

I had more of a dig around, and found that one of the chains was using a larger tree depth than the others (three chains had a tree depth of 11, one had 12). This makes me think that I probably need to increase the number of warmup samples.
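
For a rough sense of why one extra level of tree depth matters, here is a small back-of-the-envelope sketch in plain Python. It assumes that a NUTS iteration which builds a full binary tree of depth d takes 2^d - 1 leapfrog steps, and that a parallel run only finishes once its slowest chain finishes. The helper name `max_leapfrog_steps` is made up for illustration, and the numbers are an upper bound per iteration, not a measurement of this model.

```python
def max_leapfrog_steps(tree_depth: int) -> int:
    """Leapfrog steps in one NUTS iteration that builds a full tree of this depth."""
    return 2 ** tree_depth - 1


# Tree depths reported above: three chains at 11, one at 12.
chain_depths = [11, 11, 11, 12]
steps_per_chain = [max_leapfrog_steps(d) for d in chain_depths]

# Assumption: the parallel run is only as fast as its slowest chain.
bottleneck_steps = max(steps_per_chain)
typical_steps = min(steps_per_chain)
slowdown = bottleneck_steps / typical_steps

print(steps_per_chain)
print(round(slowdown, 2))
```

Under these assumptions the depth-12 chain needs about twice as many gradient evaluations per iteration as the others, which is in the same ballpark as the slowdown described above if the single-chain run behaved like the depth-11 chains.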

I guess one of the chains has a smaller step size than the others (assuming that the posterior is not multi-modal). NUTS is not very sensitive to step_size, so you can use a constant step size across all chains (by setting adapt_step_size=False).

Please also keep in mind that num_chains > 1 is provided as a convenience method. If you’re seeing an actual computational slowdown (and not just variable behavior between chains), you can try running independent processes with num_chains=1. In that case you’ll probably want to aggregate the samples yourself and use utility functions like print_summary to diagnose chain convergence, paying particular attention to inter-chain convergence.
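
To illustrate what checking inter-chain convergence on self-aggregated samples could look like, here is a minimal sketch of a basic (non-split) Gelman-Rubin R-hat for a single scalar parameter, using only the Python standard library. It is not NumPyro's implementation: the function is hand-written for illustration, and the draws are synthetic stand-ins for samples saved by four separate num_chains=1 runs. In practice print_summary would be the tool to reach for.

```python
import random
import statistics


def gelman_rubin(chains):
    """Basic (non-split) R-hat for one scalar parameter.

    chains: list of equal-length lists of draws, one list per chain.
    """
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    # W: average of the within-chain sample variances.
    within = statistics.fmean([statistics.variance(c) for c in chains])
    # B / n: sample variance of the per-chain means.
    between_over_n = statistics.variance(chain_means)
    var_estimate = (n - 1) / n * within + between_over_n
    return (var_estimate / within) ** 0.5


rng = random.Random(0)
# Hypothetical draws standing in for four independently run chains.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Same draws, with one chain shifted to mimic a chain that settled elsewhere.
stuck = mixed[:3] + [[x + 3.0 for x in mixed[3]]]

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(rhat_mixed, rhat_stuck)
```

Values close to 1 suggest the chains agree with each other; a value well above 1, as for the shifted chain here, suggests that at least one chain is exploring a different region.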

In case anybody else has this problem, @fehiepsi was correct that some chains were slower than others. One chain had a larger tree depth than the others, so it was ‘bottlenecking’ the computation.

I resolved this issue by increasing the number of warmup samples, so that all of the chains adapted similarly.
