How to run more chains than the number of devices in parallel?

Hi,

We have 8 devices and are trying to run 64 chains.

We get the following warning:

UserWarning: There are not enough devices to run parallel chains: expected 64 but got 8. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(64)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.

We assumed that 8 chains would run in parallel and that, as they finished, the next 8 would start, and so on until all 64 chains were complete.

Clearly, either our assumption is wrong somewhere, or there are additional settings we need to configure.

Reference code:

```python
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


# 1. Define the model
def model(x, y=None):
    # Priors on slope, intercept, and noise scale
    slope = numpyro.sample("slope", dist.Normal(0, 1))
    intercept = numpyro.sample("intercept", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))

    # Expected mean of y
    mean = intercept + slope * x
    # Likelihood
    numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)


# 2. Generate synthetic data
key = random.PRNGKey(0)
data_key, mcmc_key = random.split(key)  # separate keys for data and inference
true_slope, true_intercept, true_sigma = 2.0, -1.0, 0.5
x_data = jnp.linspace(-1, 1, 50)
y_data = (
    true_intercept + true_slope * x_data + true_sigma * random.normal(data_key, shape=(50,))
)

# 3. Run inference (64 chains requested, 8 devices available)
nuts = NUTS(model)
mcmc = MCMC(
    nuts, num_warmup=500, num_samples=1000, num_chains=64, chain_method="parallel"
)
mcmc.run(mcmc_key, x_data, y_data)

# 4. Inspect results
mcmc.print_summary()

# Posterior samples
posterior_samples = mcmc.get_samples()
print("Posterior slope mean:", posterior_samples["slope"].mean())
print("Posterior intercept mean:", posterior_samples["intercept"].mean())
```

You can call `run` and then `get_samples` multiple times with different keys, for example 8 chains (one per device) per call, and combine the samples afterwards.