MCMC with num_chains > 1 hangs in _winapi.WaitForMultipleObjects on Windows

Hi everyone,

I'm having trouble running multiple Markov chains in parallel (num_chains > 1) on Windows. The problem has already been reported on GitHub, but that issue is more than two years old and I can't find any updates on it.

Here is a minimal working example (MWE):

import pyro
import torch
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Coin-flip observations: 6 heads (1.0) followed by 4 tails (0.0), as a
# 1-D float tensor of shape (10,). Built directly with tensor factories
# instead of appending 0-d tensors to a Python list and re-packing them
# with torch.tensor.
data = torch.cat([torch.ones(6), torch.zeros(4)])

# define model
def model(data):
    """Beta-Bernoulli model of a coin's probability of landing heads.

    Args:
        data: coin-flip outcomes (1.0 = heads, 0.0 = tails); assumed to be
            a 1-D tensor, each element one conditionally independent
            Bernoulli observation.
    """
    # Hyperparameters of the Beta prior; Beta(10, 10) is symmetric around
    # 0.5, i.e. a prior belief that the coin is roughly fair.
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # Latent probability of heads, drawn from the Beta prior.
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # Vectorized plate over the observations. The size is given explicitly
    # so Pyro can check that `data` lines up with the plate instead of
    # relying on an unsized broadcast plate.
    with pyro.plate('observe_data', size=len(data)):
        # Likelihood: each flip ~ Bernoulli(f).
        pyro.sample('obs', dist.Bernoulli(f), obs=data)


# run MCMC
# With num_chains > 1, Pyro runs each chain in a separate worker process.
# On Windows the only multiprocessing start method is "spawn": every worker
# starts a fresh interpreter, re-imports this file as a module and looks up
# `model` by name. Two consequences:
#   * The sampling call must be behind an ``if __name__ == "__main__":``
#     guard. Otherwise each worker would try to launch MCMC again while it
#     is still bootstrapping.
#   * `model` must live at the top level of an importable .py file. When the
#     script is executed from Jupyter/IPython, the workers typically cannot
#     re-import functions defined in the interactive session. They then die
#     and the parent waits forever on the result queue, which is the hang
#     seen in _winapi.WaitForMultipleObjects. Run this as a plain script
#     (`python test.py`), or move `model` into a separate module and import
#     it from the notebook.
if __name__ == "__main__":
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
    mcmc.run(data)

When I run it, the output below appears and then no further progress is made:

[screenshot omitted]

If I interrupt it, the resulting KeyboardInterrupt traceback shows where it is blocked:

KeyboardInterrupt                         Traceback (most recent call last)
c:\PATH\test.py in <module>
     32 
     33 mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
---> 34 mcmc.run(data)

c:\PATH\lib\site-packages\pyro\poutine\messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

c:\PATH\lib\site-packages\pyro\infer\mcmc\api.py in run(self, *args, **kwargs)
    561             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    562             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563             for x, chain_id in self.sampler.run(*args, **kwargs):
    564                 if num_samples[chain_id] == 0:
    565                     num_samples[chain_id] += 1

c:\PATH\lib\site-packages\pyro\infer\mcmc\api.py in run(self, *args, **kwargs)
    334             while active_workers:
    335                 try:
--> 336                     chain_id, val = self.result_queue.get(timeout=5)
...
--> 816             res = _winapi.WaitForMultipleObjects(L, False, timeout)
    817             if res == WAIT_TIMEOUT:
    818                 break

KeyboardInterrupt: 

Is there a known fix or workaround for this, or is switching to Linux the only option?