I’ve been experimenting with the `parallel` option for running multiple chains at once. Running several chains at the same time seems to be slower than running one, which I wouldn’t have expected, so I wanted to check whether this behaviour is expected.
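```python
import numpyro
from numpyro.infer import MCMC, NUTS, init_to_median

num_chains = 4
# Expose num_chains CPU devices to JAX so each chain can run on its own device.
# This only takes effect if called at the start of the program.
numpyro.set_host_device_count(num_chains)

...

nuts_kernel = NUTS(
    some_model,
    init_strategy=init_to_median,
    target_accept_prob=target_accept,
    max_tree_depth=max_tree_depth,
)
mcmc = MCMC(
    nuts_kernel,
    num_samples=num_samples,
    num_warmup=num_warmup,
    num_chains=num_chains,
    chain_method='parallel',  # run the chains concurrently, one per device
)
mcmc.run(rng_key)
```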
This setup has very different runtimes depending on the number of chains I request. For a model that I’m running locally, with `num_chains=1` warmup takes 2 hours and sampling takes 5 hours. If I instead set `num_chains=4`, warmup takes 5 hours and sampling takes 12 hours. That is more than 2x slower!
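To put those timings side by side, here is a quick back-of-the-envelope calculation. It is a minimal sketch that assumes both runs used the same `num_samples` per chain; the dictionary and function names are hypothetical and only hold the wall-clock hours quoted above.

```python
# Wall-clock hours from the two runs above (hypothetical names; assumes the
# same num_samples per chain in both runs).
timings = {
    1: {"warmup": 2.0, "sampling": 5.0},
    4: {"warmup": 5.0, "sampling": 12.0},
}


def slowdown(phase: str) -> float:
    """Wall-clock ratio of the 4-chain run to the 1-chain run for one phase."""
    return timings[4][phase] / timings[1][phase]


def draws_per_hour(num_chains: int) -> float:
    """Post-warmup draws per hour, in units of num_samples.

    Each chain produces num_samples draws, so a run yields num_chains units.
    """
    return num_chains / timings[num_chains]["sampling"]


warmup_slowdown = slowdown("warmup")
sampling_slowdown = slowdown("sampling")
throughput_gain = draws_per_hour(4) / draws_per_hour(1)
# 1.0 would mean the 4 chains ran in the same wall-clock time as 1 chain.
parallel_efficiency = throughput_gain / 4

print(f"warmup slowdown:     {warmup_slowdown:.2f}x")
print(f"sampling slowdown:   {sampling_slowdown:.2f}x")
print(f"throughput gain:     {throughput_gain:.2f}x")
print(f"parallel efficiency: {parallel_efficiency:.0%}")
```

Under that assumption, each phase is roughly 2.4 to 2.5x slower in wall-clock time, while the 4-chain run still produces about 1.67x as many post-warmup draws per hour, i.e. roughly 42% of ideal 4x scaling.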
I would have expected to be able to run 4 chains in parallel without paying much in wall-clock time.
Is this behaviour expected?