Error when sampling begins with multiple chains

I get errors when trying to run multiple chains in parallel using HMC, and the problem persists on multiple computers and across reboots. The error is

RuntimeError: received 0 items of ancdata.

See https://pastebin.com/bwApm5AY for the full error.

Any ideas as to what might be wrong? I run inference with

from pyro.infer.mcmc import MCMC, NUTS

nuts_kernel = NUTS(model, max_tree_depth=5)
posterior = MCMC(nuts_kernel, num_samples=500, warmup_steps=2000, num_chains=5).run(x, y)

and my model looks like

import torch
import pyro
from pyro.distributions import Normal

def model(x, y):
    prior_std_mean = 1.0
    prior_std_var = 0.5

    fc1_mean_weight_prior = Normal(loc=torch.zeros_like(net.fc1_mean.weight), scale=prior_std_mean*torch.ones_like(net.fc1_mean.weight))
    fc1_mean_bias_prior = Normal(loc=torch.zeros_like(net.fc1_mean.bias), scale=prior_std_mean*torch.ones_like(net.fc1_mean.bias))

    fc2_mean_weight_prior = Normal(loc=torch.zeros_like(net.fc2_mean.weight), scale=prior_std_mean*torch.ones_like(net.fc2_mean.weight))
    fc2_mean_bias_prior = Normal(loc=torch.zeros_like(net.fc2_mean.bias), scale=prior_std_mean*torch.ones_like(net.fc2_mean.bias))

    fc1_var_weight_prior = Normal(loc=torch.zeros_like(net.fc1_var.weight), scale=prior_std_var*torch.ones_like(net.fc1_var.weight))
    fc1_var_bias_prior = Normal(loc=torch.zeros_like(net.fc1_var.bias), scale=prior_std_var*torch.ones_like(net.fc1_var.bias))

    fc2_var_weight_prior = Normal(loc=torch.zeros_like(net.fc2_var.weight), scale=prior_std_var*torch.ones_like(net.fc2_var.weight))
    fc2_var_bias_prior = Normal(loc=torch.zeros_like(net.fc2_var.bias), scale=prior_std_var*torch.ones_like(net.fc2_var.bias))

    priors = {"fc1_mean.weight": fc1_mean_weight_prior, "fc1_mean.bias": fc1_mean_bias_prior,
              "fc2_mean.weight": fc2_mean_weight_prior, "fc2_mean.bias": fc2_mean_bias_prior,
              "fc1_var.weight": fc1_var_weight_prior, "fc1_var.bias": fc1_var_bias_prior,
              "fc2_var.weight": fc2_var_weight_prior, "fc2_var.bias": fc2_var_bias_prior}
    
    lifted_module = pyro.random_module("module", net, priors)

    sampled_reg_model = lifted_module()
    
    mu, log_sigma_2 = sampled_reg_model(x)
    sigma = torch.sqrt(torch.exp(log_sigma_2))

    return pyro.sample("obs", pyro.distributions.Normal(mu, sigma), obs=y)

where net is a regular neural network made of fully connected layers with ReLU activations, with separate fc1/fc2 layers for the mean and the log-variance outputs.

Thank you!

Hi @jboyml, it looks like I ran into the same problem as you. Could you try adding

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

at the beginning of your script?

Thanks for your response. It now gets to the sampling phase, but my Jupyter kernel crashes during sampling, and the following error is printed five times:

KernelRestarter: restarting kernel (1/5), keep random ports
kernel 141fa090-50f7-4acd-85b0-8a50054fc956 restarted
Traceback (most recent call last):
  File "/home/john/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/john/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/john/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/john/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 657, in launch_instance
    app.initialize(argv)
  File "", line 2, in initialize
  File "/home/john/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 87, in catch_config_error
    return method(app, *args, **kwargs)
  File "/home/john/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 469, in initialize
    self.init_sockets()
  File "/home/john/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 238, in init_sockets
    self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
  File "/home/john/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 180, in _bind_socket
    s.bind("tcp://%s:%i" % (self.ip, port))
  File "zmq/backend/cython/socket.pyx", line 547, in zmq.backend.cython.socket.Socket.bind
  File "zmq/backend/cython/checkrc.pxd", line 25, in zmq.backend.cython.checkrc._check_rc
zmq.error.ZMQError: Address already in use

Just a heads-up: if you are using Windows, note that multiprocessing support on Windows is completely untested.

Are you able to run this from your terminal without using a Jupyter notebook?

I got the same error previously because of a shared-memory issue. Could you install the Pyro dev version and try the new interface instead (from pyro.infer.mcmc.api import MCMC), replacing

posterior = MCMC(nuts_kernel, num_samples=500, warmup_steps=2000,
                 num_chains=5).run(x, y)

by

mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=2000, num_chains=5)
mcmc.run(x, y)
mcmc.get_samples()

I now receive the same error as before, with

>>> pyro.__version__
'0.3.3+0113e0b3'

and

from pyro.infer.mcmc.api import MCMC, NUTS

It may be helpful to know that it only crashes when the number of steps gets large: for example, I can run 50 steps with no problem, but 1000 warmup steps lead to a crash.

I also tried running it in a script (so not in a notebook) and the problem remains. I have Ubuntu 18.04.

@jboyml This is a long-standing PyTorch issue (https://github.com/pytorch/pytorch/issues/973) for which I couldn't find a good solution. Could you try cleaning the /dev/shm folder before running the script?

find /dev/shm -name 'torch*' -delete

While running, you can use

watch "ls -1 /dev/shm | wc -l"

to see how many torch files are created in that folder. Using the new API and set_sharing_strategy('file_system') helps in my case, but I don't think they resolve the root problem. From various PyTorch threads, it seems that using threads instead of processes might help, but I don't have enough background to dig into that. By the way, could you put together a fully reproducible script so we can look into it again? It might also help to (optionally) add checkpoints that consume samples and release the subprocesses' shared resources. In the meantime, I'll look for an easier-to-replicate script in my old notebooks. :slight_smile:

Edit: I can replicate the memory issue with the script in https://github.com/pyro-ppl/pyro/issues/1730 by setting a high num_samples. I'll dig into this again to see if there are better solutions.
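For anyone who prefers to do the cleanup and monitoring from Python, here is a rough standard-library sketch of the find and watch commands above. It assumes, as the find pattern does, that PyTorch's shared-memory files in /dev/shm have names starting with "torch"; the helper names are hypothetical. Unlike find, it does not recurse into subdirectories, and it counts only matching names, whereas the ls | wc -l pipeline counts everything in the folder. Only run the cleanup before sampling starts, since deleting files that a running process still needs could break it.

```python
import time
from pathlib import Path
from typing import List


def shm_entries(directory: str = "/dev/shm", prefix: str = "torch") -> List[str]:
    """Sorted names of entries in `directory` that start with `prefix` (non-recursive)."""
    path = Path(directory)
    if not path.is_dir():
        return []
    return sorted(p.name for p in path.iterdir() if p.name.startswith(prefix))


def remove_shm_entries(directory: str = "/dev/shm", prefix: str = "torch") -> int:
    """Delete matching regular files, similar to: find DIR -name 'PREFIX*' -delete."""
    removed = 0
    for name in shm_entries(directory, prefix):
        target = Path(directory) / name
        if target.is_file():
            try:
                target.unlink()
                removed += 1
            except FileNotFoundError:
                pass  # another process removed it first
    return removed


def watch_shm(directory: str = "/dev/shm", prefix: str = "torch",
              interval: float = 2.0, iterations: int = 10) -> None:
    """Print a timestamped count of matching entries every `interval` seconds."""
    for _ in range(iterations):
        count = len(shm_entries(directory, prefix))
        print(time.strftime("%H:%M:%S"), count, flush=True)
        time.sleep(interval)
```

For example, call remove_shm_entries() before launching the sampler and run watch_shm() in a second terminal while it runs; a count that keeps climbing with the number of samples matches the behavior described in this thread.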


@jboyml Could you help me test the following solution: replace this line with

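z_acc[k][chain_id].append(v.clone())  # copy the received tensor out of shared memory
del v  # drop the reference so the shared-memory storage can be released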

@neerajprad I believe doing it this way will resolve these memory problems (pending @jboyml's confirmation). I checked /dev/shm, and the number of files stayed consistently low during the MCMC run. I can also run the script from https://github.com/pyro-ppl/pyro/issues/1730 with a large number of samples (num_samples=10000, 4 chains). Because clone adds some overhead, we might support a checkpoint argument that decides how many samples to consume at a time during sampling. For example, checkpoint=1 is equivalent to the above suggestion

z_acc[k][chain_id].append(v.clone())
del v

while for checkpoint > 1 we could replace clone with torch.stack. WDYT?
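The proposed checkpoint argument can be sketched without PyTorch as a small buffering helper. SampleAccumulator, checkpoint and consolidate are hypothetical names, not part of Pyro's API; in the real sampler the samples would be tensors received from the chain subprocesses (as in z_acc[k][chain_id] above), and consolidate would be something like torch.stack, which copies the buffered tensors into one ordinary tensor so their shared storage can be freed once nothing else references it.

```python
from typing import Any, Callable, List, Sequence


class SampleAccumulator:
    """Hypothetical helper sketching a `checkpoint` option for consuming samples.

    Incoming samples are buffered; once `checkpoint` of them have arrived, the
    buffer is consolidated into one chunk and the references to the received
    objects are dropped. The default `consolidate=list` only copies references;
    with tensors, a copying function such as torch.stack would be passed instead.
    """

    def __init__(self, checkpoint: int = 1,
                 consolidate: Callable[[Sequence[Any]], Any] = list) -> None:
        if checkpoint < 1:
            raise ValueError("checkpoint must be >= 1")
        self.checkpoint = checkpoint
        self.consolidate = consolidate
        self.chunks: List[Any] = []
        self._buffer: List[Any] = []

    def add(self, sample: Any) -> None:
        """Buffer one received sample, consolidating every `checkpoint` samples."""
        self._buffer.append(sample)
        if len(self._buffer) >= self.checkpoint:
            self.flush()

    def flush(self) -> None:
        """Consolidate whatever is buffered and drop the buffered references."""
        if self._buffer:
            self.chunks.append(self.consolidate(self._buffer))
            # Rebind instead of clear(): consolidate might return the buffer itself.
            self._buffer = []

    def samples(self) -> List[Any]:
        """Flush any partial buffer and return all samples in arrival order."""
        self.flush()
        return [s for chunk in self.chunks for s in chunk]
```

The trade-off is the one described above: a larger checkpoint means fewer consolidation calls, but up to checkpoint received samples (and their shared-memory files) stay alive between flushes.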


That seems to do the trick, thank you! Before the fix, the number of files in /dev/shm increased rapidly once sampling started. After the fix, only 16 files per chain are added, and they are properly removed when sampling finishes.
