MCMC with num_chains>1 stuck on winapi.WaitForMultipleObjects

Hi everyone,

I’m having trouble running multiple Markov chains (num_chains > 1) on Windows. This problem has already been reported on GitHub, but that issue is more than two years old and has had no updates.

Here’s an MWE:

import pyro
import torch
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))
data = torch.tensor(data)

# define model
def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # vectorized plate over the observed data
    with pyro.plate('observe_data'):
        # likelihood Bernoulli(f)
        pyro.sample('obs', dist.Bernoulli(f), obs=data)

# run MCMC
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
mcmc.run(data)

When I run it, it just hangs without making any progress.

If I kill it, it shows where it was stuck:

KeyboardInterrupt                         Traceback (most recent call last)
c:\PATH\ in <module>
     33 mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
---> 34

c:\PATH\lib\site-packages\pyro\poutine\ in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

c:\PATH\lib\site-packages\pyro\infer\mcmc\ in run(self, *args, **kwargs)
    561             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    562             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563             for x, chain_id in*args, **kwargs):
    564                 if num_samples[chain_id] == 0:
    565                     num_samples[chain_id] += 1

c:\PATH\lib\site-packages\pyro\infer\mcmc\ in run(self, *args, **kwargs)
    334             while active_workers:
    335                 try:
--> 336                     chain_id, val = self.result_queue.get(timeout=5)
--> 816             res = _winapi.WaitForMultipleObjects(L, False, timeout)
    817             if res == WAIT_TIMEOUT:
    818                 break
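
My reading of the frame at line 336 is that the parent process polls a result queue with a 5-second timeout and simply retries when nothing arrives. The `except` branch isn't shown in the traceback, so that part is an assumption. Here is a toy sketch of that pattern, with threads standing in for Pyro's worker processes. `worker`, `poll` and `run_demo` are hypothetical names, not Pyro's API. It shows that if a worker never reports back (for example, a child process that failed to start), the loop keeps waiting instead of raising an error, which would match the hang I see.

```python
import queue
import threading
import time


def worker(chain_id, result_queue, report=True):
    # Hypothetical stand-in for one chain: send one result, then a None
    # sentinel meaning "this chain is done". With report=False it sends
    # nothing, like a worker process that never started properly.
    if report:
        result_queue.put((chain_id, "sample"))
        result_queue.put((chain_id, None))


def poll(result_queue, num_workers, deadline):
    # Polls the queue with a short timeout and retries on queue.Empty.
    # The deadline exists only so this sketch terminates; I'm assuming
    # the loop in the traceback has no such exit.
    active = num_workers
    results = []
    while active:
        if time.monotonic() > deadline:
            break
        try:
            chain_id, val = result_queue.get(timeout=0.1)
        except queue.Empty:
            continue  # nothing arrived yet: wait again
        if val is None:
            active -= 1  # one chain finished
        else:
            results.append((chain_id, val))
    return results, active


def run_demo(reporting, wait_seconds=1.0):
    # reporting: one bool per worker saying whether it ever reports back.
    result_queue = queue.Queue()
    threads = [
        threading.Thread(target=worker, args=(i, result_queue, flag))
        for i, flag in enumerate(reporting)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return poll(result_queue, len(reporting), time.monotonic() + wait_seconds)
```

With both workers reporting, `run_demo([True, True])` returns both samples and 0 remaining workers. With `run_demo([True, False])` it gives up at the deadline with 1 worker still counted as active; in a loop without a deadline, that would mean waiting forever.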


Do you know if there is a solution? Or should I just switch to Linux?