MCMC spawn multi-processing: bad value(s) in fds_to_keep

I’m attempting to run the Bayesian regression tutorial below. I’m running this in a Jupyter notebook hosted on an AWS server. I’ve been trying to get MCMC to work when num_chains is greater than 1, on a CPU. I’ve hit the following roadblocks:

  • when mp_context is blank or “fork”, I receive RuntimeError: Unable to handle autograd's threading in combination with fork-based multiprocessing. See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork. This led me to try the other mp_context options.
  • when mp_context is “forkserver”, four progress bars show up but none ever start. This isn’t just a progress bar issue: setting num_samples and warmup_steps to 1 doesn’t make them finish either.
  • when mp_context is “spawn”, I receive ValueError: bad value(s) in fds_to_keep.

My sense is that mp_context should be “spawn” in this environment, but I don’t know how to address the fds_to_keep error – some googling has shown that this occurs sometimes in PyTorch, but I’m not sure how to solve it for the Pyro use case.

import numpy as np
import pandas as pd
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

pyro.enable_validation(True)
pyro.set_rng_seed(1)
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

def model(is_cont_africa, ruggedness, log_gdp):
    a = pyro.sample("a", dist.Normal(8., 1000.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    # pyro.plate supersedes the deprecated pyro.iarange
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)].copy()  # copy so the log assignment below doesn't trigger SettingWithCopyWarning
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

nuts_kernel = NUTS(model, adapt_step_size=True)
# MCMC.run() returns None, so construct the MCMC object first and then run it
hmc_posterior = MCMC(nuts_kernel, num_samples=4000, warmup_steps=1000, num_chains=4, mp_context="spawn")
hmc_posterior.run(is_cont_africa, ruggedness, log_gdp)
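For reference, the mp_context values tried above are the standard start-method names from Python's multiprocessing module. Which ones exist depends on the platform: Windows only supports spawn, macOS has defaulted to spawn since Python 3.8, and Linux has traditionally defaulted to fork. The sketch below uses only the standard library (start_method_contexts is a hypothetical helper, not part of Pyro) to list what the current interpreter supports and what its default is; it assumes, without checking Pyro's source, that mp_context is resolved to one of these contexts.

```python
import multiprocessing as mp


def start_method_contexts():
    """Return {start_method_name: context} for every start method this platform supports."""
    return {name: mp.get_context(name) for name in mp.get_all_start_methods()}


if __name__ == "__main__":
    contexts = start_method_contexts()
    print("available start methods:", sorted(contexts))
    # allow_none=True reports None instead of fixing the default as a side effect
    print("current default:", mp.get_start_method(allow_none=True))
```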

Were you able to reproduce this error locally on a Mac or Linux machine? If not, it’s possible that this is a PyTorch issue whose resolution depends on the details of your AWS instance and image, so you might get better answers on the PyTorch forum, especially if you can reproduce the error in a pure PyTorch example.

On my local mac environment,

  • mp_context = ‘fork’ --> completes fine
  • mp_context = ‘forkserver’ --> progress bars show up but never start
  • mp_context = ‘spawn’ --> progress bars show up but never start

So that means that ‘forkserver’ hasn’t worked in either environment, and ‘spawn’ may be an AWS instance-specific problem that I’ll need to dig into. Do you know whether ‘forkserver’ is supposed to work for Pyro multiprocessing?

I think the spawn start method is not supported by PyTorch in notebooks. We are tracking one of the related issues here, but there has been no progress on it so far. Besides torch.multiprocessing.set_start_method, there are several other problems with multi-chain MCMC:

  • behavior on Linux differs from behavior on macOS
  • sometimes we need to change the sharing strategy using torch.multiprocessing.set_sharing_strategy
  • on some systems, running multiple chains on CPU hangs if the GPU build of PyTorch is installed

I haven’t kept up with the current progress of multiprocessing in PyTorch, so hopefully those issues can be resolved with a few changes. We would appreciate any help making multi-chain MCMC work seamlessly across Windows/Linux/macOS, CPU/GPU, and script/notebook. Currently, we have problems in each of those environments.
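One plausible reason spawn and forkserver misbehave in a notebook: unlike fork, these start methods have to pickle whatever is sent to the child processes (presumably the NUTS kernel and the model it wraps). Functions are pickled by reference, as module plus name, and the child re-imports them. A function defined in a notebook cell lives in __main__, and a spawned child's __main__ is not the notebook, so that lookup can fail in the child. A commonly suggested workaround is to move the model into an importable .py module. The sketch below is a conservative, standard-library-only check; spawn_safe is a hypothetical helper, not a Pyro or PyTorch function, and it also flags functions defined in a script's __main__, which spawn can usually handle.

```python
import pickle


def spawn_safe(fn):
    """Conservative check of whether `fn` can be shipped to a spawn/forkserver child.

    Unpicklable objects (e.g. lambdas) fail outright. Functions whose module is
    "__main__" pickle in the parent, but in a notebook the child's "__main__"
    is a different module, so they are treated as unsafe here.
    """
    try:
        pickle.dumps(fn)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return getattr(fn, "__module__", None) != "__main__"
```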


Thank you @fehiepsi!

It would be nice to allow num_chains > 1 even when multiprocessing isn’t possible. If my understanding is correct, sampling multiple chains is beneficial for MCMC even when there is no performance/multiprocessing gain, as it reduces reliance on the initialization of the parameters, particularly for complex distributions. So even if you can’t parallelize the chains, it would be nice to be able to run multiple chains.

Based only on a brief examination of the code here, I think this would be possible by optionally passing parallel=False or max_cpus=1 to MCMC.__init__(), to force chains to be drawn sequentially.

Am I thinking through that correctly?
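The sequential fallback proposed above amounts to a simple dispatch. Here is a minimal sketch with a hypothetical run_chains helper; its name and parallel flag mirror the proposal and are not an existing Pyro API. run_one_chain receives a chain index, which it can also use as a per-chain seed. With parallel=False nothing is pickled or forked; the parallel path uses concurrent.futures.ProcessPoolExecutor and therefore inherits the same start-method constraints discussed above.

```python
from concurrent.futures import ProcessPoolExecutor


def run_chains(run_one_chain, num_chains, parallel=False, max_workers=None, mp_context=None):
    """Hypothetical helper: call run_one_chain(chain_id) for each chain, results in chain order."""
    chain_ids = range(num_chains)
    if not parallel:
        # Sequential: everything stays in the current process (works in notebooks)
        return [run_one_chain(i) for i in chain_ids]
    # Parallel: run_one_chain must be picklable for non-fork start methods
    with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as pool:
        return list(pool.map(run_one_chain, chain_ids))
```

For Pyro, run_one_chain would build a NUTS kernel and a single-chain MCMC, call run(), and return get_samples(), much like the sequential loop in the workaround posted in this thread.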

As a workaround, I’m using something like this to manually draw chains sequentially:

all_samples_list = []
for i in range(4):
    # Fresh kernel and single-chain MCMC per iteration, so chains run one after another
    nuts_kernel = NUTS(model, adapt_step_size=True)
    hmc_posterior = MCMC(nuts_kernel, num_samples=100, warmup_steps=1, num_chains=1)
    hmc_posterior.run(is_cont_africa, ruggedness, log_gdp)
    all_samples_list.append(hmc_posterior.get_samples())

# Concatenate each site's samples from all chains along the sample dimension
all_samples_flat = {}
for k in all_samples_list[0].keys():
    all_samples_flat[k] = torch.tensor([])
for d in all_samples_list:
    for k, v in d.items():
        all_samples_flat[k] = torch.cat([all_samples_flat[k], v], dim=0)

all_samples_flat
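Concatenating along dim 0 pools the draws but discards which chain each draw came from, and that chain identity is what lets multiple chains expose sensitivity to initialization. A minimal NumPy sketch that keeps a (num_chains, num_samples) layout per scalar site and computes the classic (non-split) Gelman-Rubin R-hat: values close to 1 suggest the chains agree, while values well above 1 indicate they have not mixed. stack_chains and gelman_rubin are illustrative helpers; np.asarray also accepts CPU tensors like those returned by get_samples(). Pyro's pyro.ops.stats module provides gelman_rubin and split_gelman_rubin for the same purpose.

```python
import numpy as np


def stack_chains(samples_per_chain):
    """List of {site: 1-D samples} dicts (one per chain) -> {site: array (num_chains, num_samples)}."""
    return {
        k: np.stack([np.asarray(d[k]) for d in samples_per_chain], axis=0)
        for k in samples_per_chain[0]
    }


def gelman_rubin(x):
    """Classic Gelman-Rubin R-hat for a scalar site, x of shape (num_chains, num_samples)."""
    x = np.asarray(x, dtype=float)
    m, n = x.shape
    if m < 2 or n < 2:
        raise ValueError("need at least 2 chains with at least 2 samples each")
    within = x.var(axis=1, ddof=1).mean()          # W: mean within-chain variance
    between = n * x.mean(axis=1).var(ddof=1)       # B: n * variance of the chain means
    var_hat = (n - 1) / n * within + between / n   # pooled estimate of the posterior variance
    return float(np.sqrt(var_hat / within))
```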

@neerajprad Should we add an arg chain_method as in NumPyro to let users draw chains sequentially?

@rsyoung Would you mind creating a feature request in Pyro, so we can follow up on this enhancement? Thanks!

Created a feature request here.


Found a workaround using joblib that works on my AWS instance and locally. Just have to make sure not to set num_threads too high or I run out of memory.

from joblib import Parallel, delayed

def get_samples_chain(num_samples, warmup_steps, seed):
    # Each joblib worker is a separate process, so seed every chain explicitly
    pyro.set_rng_seed(seed)
    nuts_kernel = NUTS(model, adapt_step_size=True)
    hmc_posterior = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)
    hmc_posterior.run(is_cont_africa, ruggedness, log_gdp)
    return hmc_posterior.get_samples()

num_chains = 10
num_threads = 5  # with joblib's default (loky) backend, this is the number of worker processes
all_samples_list = Parallel(n_jobs=num_threads)(
    delayed(get_samples_chain)(100, 10, seed=i) for i in range(num_chains)
)
    
all_samples_flat = {
    k : torch.cat([sample_set[k] for sample_set in all_samples_list], dim = 0) for k in all_samples_list[0].keys()
}
        
all_samples_flat
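A likely reason this works where mp_context="spawn" did not: joblib's default loky backend serializes functions defined in the notebook with cloudpickle, so the workers do not need to import them from __main__. One caveat is that each worker is a fresh process, and PyTorch's default CPU generator starts from the same fixed seed in every new process (see torch.initial_seed()), so chains started in different workers may begin from identical random states unless each chain is seeded. A minimal sketch of deriving independent per-chain seeds from one base seed with NumPy's SeedSequence; per_chain_seeds is a hypothetical helper.

```python
import numpy as np


def per_chain_seeds(base_seed, num_chains):
    """Derive num_chains independent 32-bit integer seeds from one base seed."""
    # SeedSequence.spawn yields statistically independent child sequences
    children = np.random.SeedSequence(base_seed).spawn(num_chains)
    return [int(child.generate_state(1)[0]) for child in children]
```

Each seed can then be handed to a version of get_samples_chain that accepts a seed argument and calls pyro.set_rng_seed(seed) before building the kernel.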