Hi everyone,
I’m having issues running multiple Markov chains with Pyro’s MCMC on Windows: with `num_chains=2`, `mcmc.run` hangs and never makes progress. I see that the problem has already been reported on GitHub, but that issue is more than two years old and I don’t see any updates.
Here’s an MWE:
import pyro
import torch
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
# create some data with 6 observed heads and 4 observed tails
# (1.0 = heads, 0.0 = tails); built directly as one 1-D float tensor
# instead of appending 0-d tensors to a list and re-wrapping it.
data = torch.tensor([1.0] * 6 + [0.0] * 4)
# define model
def model(data):
    """Beta-Bernoulli model for the fairness of a coin.

    Draws the latent fairness ``f`` from a ``Beta(10, 10)`` prior at the
    ``"latent_fairness"`` site and conditions a ``Bernoulli(f)`` likelihood
    on the observations at the ``'obs'`` site.

    Args:
        data: observed flips, passed as ``obs`` to the Bernoulli site
            (in this example a 1-D tensor with 1.0 = heads, 0.0 = tails).
    """
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # vectorized plate over the observed data
    with pyro.plate('observe_data'):
        # likelihood Bernoulli(f)
        pyro.sample('obs', dist.Bernoulli(f), obs=data)
# run MCMC
# With num_chains > 1 the chains are sampled in worker processes. On Windows,
# multiprocessing uses the "spawn" start method, which re-imports this module
# in every worker; without the __main__ guard each worker would re-execute
# the sampling code below on import instead of doing its own work.
# NOTE(review): the guard only helps when this file is run as a script
# (e.g. `python test.py`); from an interactive/IPython session the workers
# may still be unable to load `model` -- confirm on the target setup.
if __name__ == "__main__":
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
    mcmc.run(data)
When I run it, it hangs and shows no progress (the screenshot is not reproduced here):
If I kill it, it shows where it was stuck:
KeyboardInterrupt Traceback (most recent call last)
c:\PATH\test.py in <module>
32
33 mcmc = MCMC(nuts_kernel, num_samples=50, warmup_steps=10, num_chains=2)
---> 34 mcmc.run(data)
c:\PATH\lib\site-packages\pyro\poutine\messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
c:\PATH\lib\site-packages\pyro\infer\mcmc\api.py in run(self, *args, **kwargs)
561 # requires_grad", which happens with `jit_compile` under PyTorch 1.7
562 args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563 for x, chain_id in self.sampler.run(*args, **kwargs):
564 if num_samples[chain_id] == 0:
565 num_samples[chain_id] += 1
c:\PATH\lib\site-packages\pyro\infer\mcmc\api.py in run(self, *args, **kwargs)
334 while active_workers:
335 try:
--> 336 chain_id, val = self.result_queue.get(timeout=5)
...
--> 816 res = _winapi.WaitForMultipleObjects(L, False, timeout)
817 if res == WAIT_TIMEOUT:
818 break
KeyboardInterrupt:
Do you know if there is a solution? Or should I just switch to Linux?