CPU Usage with NUTS

Hi,
I am using NUTS to fit a multivariate normal model (dim=12) to a dataset. I noticed that when running NUTS with one chain, CPU usage is very high (600% CPU). When I increased the number of chains to six, CPU usage dropped (to 130% CPU), but the running time increased significantly.
I thought the dataset might be the cause, so I ran lkj.py from the pyro/examples folder, and the same problem occurred.
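
For context, a minimal sketch of this kind of setup (toy data and illustrative priors, not my actual model or dataset) looks like:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Toy stand-in for the real dataset: 1000 observations of a 12-dim variable.
data = torch.randn(1000, 12)

def model(y):
    # Illustrative priors on the mean and (diagonal) scale of a 12-dim normal.
    loc = pyro.sample("loc", dist.Normal(torch.zeros(12), 10.0).to_event(1))
    scale = pyro.sample("scale", dist.HalfCauchy(torch.ones(12)).to_event(1))
    with pyro.plate("data", y.shape[0]):
        pyro.sample("obs", dist.Normal(loc, scale).to_event(1), obs=y)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=500, num_chains=1)
mcmc.run(data)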


P.S. The Pyro version I am using is 1.3.1.
I would really appreciate any help!

I'm seeing similar behavior:

  1. Full usage of all cores when num_chains=1
  2. Very slow sampling if num_chains > 1
  3. I see this issue on my Ubuntu machine
  4. I do not see this issue on my MacBook Pro (even when running in an Ubuntu Docker image)

Ubuntu

System: Ubuntu 18.04, CPU: i7 (6 cores 12 threads)
All data fits into RAM, so this is not an I/O issue.

Setting num_chains=1

mcmc = MCMC(kernel, num_samples=10_000, warmup_steps=0, num_chains=1)

I see that almost all 12 threads on my machine are in use at 100%


Running under this condition, sampling proceeds at ~200 it/s:

Sample: 100%|█████████████████████████████████████| 10000/10000 [00:50, 198.61it/s, step size=2.50e-01, acc. prob=0.009]

Setting num_chains=2

If I increase num_chains to 2, I see full usage of all 12 threads.

My sampling speed plummets (note the units are now s/it, not it/s):

Warmup [1]:   0%|                                      | 8/10000 [00:48,  7.62s/it, step size=3.12e-02, acc. prob=0.993]
Warmup [2]:   0%|                                      | 3/10000 [00:18,  5.66s/it, step size=1.56e-02, acc. prob=0.999]

OSX

Oddly, if I run this on my MBP, I see no slowdown when going to more chains.

System: Mac OSX, CPU i7 quad-core
All data fits in RAM

One chain

Sample:   0%|▏                                        | 33/10000 [00:13,  2.28it/s, step size=3.12e-02, acc. prob=0.896]

Two chains

Warmup [1]:   2%|▋                                   | 192/10000 [00:55,  1.13it/s, step size=6.25e-02, acc. prob=0.274]
Warmup [2]:  12%|████▎                              | 1235/10000 [00:55, 27.41it/s, step size=1.25e-01, acc. prob=0.048]

To make matters stranger, if I start an Ubuntu Docker container on my Mac, the numbers are the same as on OSX. (Running the same Ubuntu image on my Ubuntu machine results in the same performance hit described at the start of the post.)

The environment is consistent between the two machines:

Package          Version
---------------- -------
numpy            1.18.4
pyro-api         0.1.2
pyro-ppl         1.3.1
torch            1.5.0
torchvision      0.6.0

Thanks for raising this and providing some details. Could you post a code snippet that we can run to replicate this?

Also, if you run python examples/baseball.py (which should start 4 chains by default), do you notice the same issue?


I think this is expected behavior with PyTorch >= 1.1 (see this comment). You can set the environment variable OMP_NUM_THREADS=1 to work around it.


I apologize for the late reply; Google seems to like placing Pyro forum emails in my junk folder :roll_eyes:

Having

import os
# This needs to run before torch/pyro are imported so the OpenMP runtime picks it up.
os.environ["OMP_NUM_THREADS"] = '1'

at the top of my script seemed to fix this.
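
For anyone who lands here later: an alternative that may have the same effect (I have not verified it in this setup, so treat it as a sketch) is to cap PyTorch's intra-op threads directly:

import torch

# Limit PyTorch's intra-op (OpenMP/MKL) parallelism to a single thread,
# so multiple chains don't oversubscribe the CPU.
torch.set_num_threads(1)

If chains are spawned as separate processes, the environment-variable approach is probably still the safer bet, since it is inherited by child processes.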