GPU available but slow compared to CPU on BNN example

Hello,
I was trying to run the BNN example: http://num.pyro.ai/en/stable/examples/bnn.html

However, on a CPU I get 10x more iterations per second than on a V100 GPU!

I have checked that the GPU is seen by the script:

print(jax.devices())
[GpuDevice(id=0, process_index=0)]

What could be the source of the inefficiency? Is there any environment variable I should cross-check? Thanks

hi @campagne

i believe you would have easily found an answer to this question if you had searched previous forum posts. that is one of the main purposes of the forum.

GPU workloads are generally only faster than CPU workloads when the underlying tensor operations are sufficiently large. this is basically because GPU use incurs additional overhead.

so this behavior is expected.
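To see the overhead effect described above on your own machine, here is a minimal timing sketch. It is not taken from the BNN example: the helper name `time_matmul` and the matrix sizes are hypothetical choices for illustration. It times a jitted matrix product on whatever device JAX selects by default, calling `block_until_ready()` so that asynchronous dispatch does not hide the real cost. Running it once on the CPU machine and once on the GPU machine should show whether the GPU only pulls ahead for the larger size.

```python
import time

import jax
import jax.numpy as jnp


def time_matmul(n, repeats=10):
    """Average wall-clock seconds for one jitted n x n matmul on the default device."""
    x = jax.random.normal(jax.random.PRNGKey(0), (n, n))
    matmul = jax.jit(lambda a: a @ a)
    # Warm-up call so that compilation time is not included in the measurement.
    matmul(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block until the result is actually ready.
        matmul(x).block_until_ready()
    return (time.perf_counter() - start) / repeats


# Sizes are arbitrary: one tiny problem and one moderately large one.
for size in (8, 1024):
    print(f"n={size}: {time_matmul(size):.6f} s per call on {jax.devices()[0].platform}")
```

The networks in the BNN example are small, so each step is closer to the tiny case than to the large one.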


Thanks @martinjankowiak
I was asking because in another case, where the tensors are large, the GPU was also slower than the CPU, so the people in charge of the GPU farm are digging into the reason…

@campagne
What type of approximate inference algorithm are you using? It’s quite common for MCMC to be slower on GPU than on CPU.

If you are using VI, then there could be some real underlying issues.

Well, it is VI with a NeuTra reparameterisation, which is then used to run NUTS. So far we have found that the job keeps the CPU at 100% while GPU memory is allocated. It is as if the CPU were going back and forth with the GPU just to access the CPU’s RAM, while the computation itself is not executed on the GPU.
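One way to check from inside the script whether results actually land on the GPU is sketched below. This is only a sanity check under assumptions: the helper name `describe_backend` is hypothetical, and the attribute used to query an array's device differs between JAX versions (older releases such as the 0.2.x series expose `device()`, newer ones expose `devices()`), so the code tries both.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Report the backend JAX uses by default and the platforms of all visible devices."""
    return {
        "default_backend": jax.default_backend(),
        "devices": [d.platform for d in jax.devices()],
    }


info = describe_backend()
print(info)

# Run a small computation and ask where the result lives.
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()

# Version-dependent: newer JAX arrays have .devices(), older ones have .device().
location = y.devices() if hasattr(y, "devices") else y.device()
print("result stored on:", location)
```

If the reported default backend is `cpu` even though a GPU device is listed, the computation is not being placed on the GPU by default.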

Here is the conda list of packages:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             4.5                       1_gnu
arviz                     0.11.4                   pypi_0    pypi
astropy                   4.3.1                    pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
ca-certificates           2021.10.26           h06a4308_2
certifi                   2021.10.8        py38h06a4308_0
cftime                    1.5.1.1                  pypi_0    pypi
chex                      0.0.8                    pypi_0    pypi
corner                    2.2.1                    pypi_0    pypi
cycler                    0.11.0                   pypi_0    pypi
decorator                 5.1.0                    pypi_0    pypi
ipython                   7.29.0                   pypi_0    pypi
jax                       0.2.24                   pypi_0    pypi
jaxlib                    0.1.73+cuda11.cudnn82          pypi_0    pypi
jedi                      0.18.0                   pypi_0    pypi
kiwisolver                1.3.2                    pypi_0    pypi
ld_impl_linux-64          2.35.1               h7274673_9
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libgomp                   9.3.0               h5101ec6_17
libstdcxx-ng              9.3.0               hd4cf53a_17
matplotlib                3.4.3                    pypi_0    pypi
matplotlib-inline         0.1.3                    pypi_0    pypi
ncurses                   6.3                  heee7806_1
netcdf4                   1.5.8                    pypi_0    pypi
numpy                     1.21.4                   pypi_0    pypi
numpyro                   0.8.0                     dev_0    <develop>
openssl                   1.1.1l               h7f8727e_0
optax                     0.0.9                    pypi_0    pypi
packaging                 21.2                     pypi_0    pypi
pandas                    1.3.4                    pypi_0    pypi
parso                     0.8.2                    pypi_0    pypi
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    8.4.0                    pypi_0    pypi
pip                       21.2.4           py38h06a4308_0
prompt-toolkit            3.0.22                   pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pyerfa                    2.0.0.1                  pypi_0    pypi
pygments                  2.10.0                   pypi_0    pypi
pyparsing                 2.4.7                    pypi_0    pypi
python                    3.8.12               h12debd9_0
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2021.3                   pypi_0    pypi
readline                  8.1                  h27cfd23_0
scipy                     1.7.2                    pypi_0    pypi
setuptools                58.0.4           py38h06a4308_0
six                       1.16.0                   pypi_0    pypi
sqlite                    3.36.0               hc218d9a_0
tk                        8.6.11               h1ccaba5_0
toolz                     0.11.2                   pypi_0    pypi
tqdm                      4.62.3                   pypi_0    pypi
traitlets                 5.1.1                    pypi_0    pypi
typing-extensions         3.10.0.2                 pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
wheel                     0.37.0             pyhd3eb1b0_1
xarray                    0.20.1                   pypi_0    pypi
xz                        5.2.5                h7b6447c_0
zlib                      1.2.11               h7b6447c_3

Note that I have cloned NumPyro just to cross-check a change in transform.py (diag = jnp.clip(diag, a_min=1e-12)), but that line is commented out. I wonder whether using a cloned NumPyro could explain why the GPU is not being used?