Hi,
We have 8 devices and are trying to run 64 chains.
We get the following warning:
UserWarning: There are not enough devices to run parallel chains: expected 64 but got 8. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(64)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
We assumed that 8 chains would run in parallel and that, as they completed, the next 8 would start, and so on until all 64 chains were done.
The warning says the chains will be drawn sequentially instead, so either our assumption is wrong or there are additional settings we need to apply. Which is it?
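To make the schedule we had in mind concrete, here is a small pure-Python sketch. The helper name `assumed_schedule` is ours and purely illustrative; it describes the batching we expected, not what NumPyro actually does.

```python
def assumed_schedule(num_chains, num_devices):
    """Split chain indices into consecutive batches of at most num_devices."""
    return [
        list(range(start, min(start + num_devices, num_chains)))
        for start in range(0, num_chains, num_devices)
    ]


batches = assumed_schedule(64, 8)
print(len(batches))   # 8 batches
print(batches[0])     # [0, 1, 2, 3, 4, 5, 6, 7]
print(batches[-1])    # [56, 57, 58, 59, 60, 61, 62, 63]
```

That is, we expected 8 rounds of 8 chains each, with one chain per device in every round.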
Reference code:
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# 1. Define the model
def model(x, y=None):
    # Priors on slope and intercept
    slope = numpyro.sample("slope", dist.Normal(0, 1))
    intercept = numpyro.sample("intercept", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    # Expected mean of y
    mean = intercept + slope * x
    # Likelihood
    numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)
# 2. Generate synthetic data
# Split the key so data generation and MCMC do not share random bits
key_data, key_mcmc = random.split(random.PRNGKey(0))
true_slope, true_intercept, true_sigma = 2.0, -1.0, 0.5
x_data = jnp.linspace(-1, 1, 50)
y_data = (
    true_intercept
    + true_slope * x_data
    + true_sigma * random.normal(key_data, shape=(50,))
)
# 3. Run inference
nuts = NUTS(model)
mcmc = MCMC(
    nuts, num_warmup=500, num_samples=1000, num_chains=64, chain_method="parallel"
)
mcmc.run(key_mcmc, x_data, y_data)
# 4. Inspect results
mcmc.print_summary()
# Posterior samples
posterior_samples = mcmc.get_samples()
print("Posterior slope mean:", posterior_samples["slope"].mean())
print("Posterior intercept mean:", posterior_samples["intercept"].mean())