SVI+NeuTra/NUTS: 1 thread of ~40 at 100% CPU, the others at 0%, and only ~60% GPU usage

Hello,

I am running on CPU (Linux) + GPU (V100) hardware to perform SVI + NeuTra reparam -> MCMC/NUTS.

Well, what is puzzling me is that

  1. the SVI spawns a lot of threads (~40), none of which uses the CPU, while a single thread runs at 100% of one CPU core with little GPU usage;
  2. for the NeuTra step it is the same in terms of CPU and threads, while the GPU is used at 60% or so.

I wonder if there is something I have missed to put the CPU threads to work and use the GPU more efficiently.

Here is some info on my packages, plus some snippets.

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                      1_llvm    conda-forge
absl-py                   1.0.0              pyhd8ed1ab_0    conda-forge
astropy                   4.3.1            py38h09021b7_0  
brotli                    1.0.9                he6710b0_2  
ca-certificates           2021.10.26           h06a4308_2  
certifi                   2021.10.8        py38h06a4308_0  
colorama                  0.4.4              pyh9f0ad1d_0    conda-forge
cudatoolkit               11.3.1               h2bc3f7f_2  
cudnn                     8.2.1                cuda11.3_0  
cycler                    0.10.0                   py38_0  
dbus                      1.13.18              hb2f20db_0  
expat                     2.4.1                h2531618_2  
fontconfig                2.13.1               h6c09931_0  
fonttools                 4.25.0             pyhd3eb1b0_0  
freetype                  2.11.0               h70c0345_0  
giflib                    5.2.1                h7b6447c_0  
glib                      2.69.1               h5202010_0  
gst-plugins-base          1.14.0               h8213a91_2  
gstreamer                 1.14.0               h28cd5cc_2  
icu                       58.2                 he6710b0_3  
jax                       0.2.25             pyhd8ed1ab_0    conda-forge
jaxlib                    0.1.73+cuda11.cudnn82          pypi_0    pypi
jpeg                      9d                   h7f8727e_0  
kiwisolver                1.3.1            py38h2531618_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.35.1               h7274673_9  
libblas                   3.9.0           12_linux64_openblas    conda-forge
libcblas                  3.9.0           12_linux64_openblas    conda-forge
libffi                    3.3                  he6710b0_2  
libgcc-ng                 11.2.0              h1d223b6_11    conda-forge
libgfortran-ng            11.2.0              h69a702a_11    conda-forge
libgfortran5              11.2.0              h5c6108e_11    conda-forge
liblapack                 3.9.0           12_linux64_openblas    conda-forge
libopenblas               0.3.18          pthreads_h8fe5266_0    conda-forge
libpng                    1.6.37               hbc83047_0  
libstdcxx-ng              11.2.0              he4da1e4_11    conda-forge
libtiff                   4.2.0                h85742a9_0  
libuuid                   1.0.3                h7f8727e_2  
libwebp                   1.2.0                h89dd481_0  
libwebp-base              1.2.0                h27cfd23_0  
libxcb                    1.14                 h7b6447c_0  
libxml2                   2.9.12               h03d6c58_0  
llvm-openmp               12.0.1               h4bd325d_1    conda-forge
lz4-c                     1.9.3                h295c915_1  
matplotlib                3.4.3            py38h06a4308_0  
matplotlib-base           3.4.3            py38hbbc1b5f_0  
munkres                   1.1.4                      py_0  
ncurses                   6.3                  h7f8727e_2  
numpy                     1.21.4           py38he2449b9_0    conda-forge
numpyro                   0.8.0              pyhd8ed1ab_0    conda-forge
olefile                   0.46               pyhd3eb1b0_0  
openssl                   1.1.1l               h7f8727e_0  
opt_einsum                3.3.0              pyhd8ed1ab_1    conda-forge
pcre                      8.45                 h295c915_0  
pillow                    8.4.0            py38h5aabda8_0  
pip                       21.2.4           py38h06a4308_0  
pyerfa                    2.0.0            py38h27cfd23_0  
pyparsing                 3.0.4              pyhd3eb1b0_0  
pyqt                      5.9.2            py38h05f1152_4  
python                    3.8.12               h12debd9_0  
python-dateutil           2.8.2              pyhd3eb1b0_0  
python-flatbuffers        2.0                pyhd8ed1ab_0    conda-forge
python_abi                3.8                      2_cp38    conda-forge
pyyaml                    6.0              py38h7f8727e_1  
qt                        5.9.7                h5867ecd_1  
readline                  8.1                  h27cfd23_0  
scipy                     1.7.3            py38h56a6a73_0    conda-forge
setuptools                58.0.4           py38h06a4308_0  
sip                       4.19.13          py38he6710b0_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.36.0               hc218d9a_0  
tk                        8.6.11               h1ccaba5_0  
tornado                   6.1              py38h27cfd23_0  
tqdm                      4.62.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.0.0              pyha770c72_0    conda-forge
wheel                     0.37.0             pyhd3eb1b0_1  
xz                        5.2.5                h7b6447c_0  
yaml                      0.2.5                h7b6447c_0  
zlib                      1.2.11               h7b6447c_3  
zstd                      1.4.9                haebb681_0  

Now my Numpyro model looks schematically like

def model_spl(obs=None):
    # define 16 priors
    # compute y = f(all_prior_variables)
    # note: pass exactly one of covariance_matrix / precision_matrix
    return numpyro.sample('y',
                          dist.MultivariateNormal(y, precision_matrix=P),
                          obs=obs)

with the precision matrix P (equivalently the covariance C) of size 2500x2500, defined beforehand and constant in my modelling.

Now I fit an approximation of the posterior via SVI using the following snippet (PHASE 1):

import numpy as np
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Trace_ELBO, autoguide

guide = autoguide.AutoMultivariateNormal(
    model_spl,
    init_loc_fn=numpyro.infer.init_to_value(values=true_param))

optimizer = numpyro.optim.Adam(5e-3)

svi = SVI(model_spl, guide, optimizer, loss=Trace_ELBO())

def body_fn(i, carry):
    # one SVI step; keep track of the best state seen so far
    svi_state, svi_state_best, losses = carry
    svi_state, loss = svi.update(svi_state, cl_obs)

    def update_fn(_):
        # loss improved: record it and keep the new state as best
        return losses.at[i].set(loss), svi_state
    def keep_fn(_):
        # no improvement: carry the previous loss and best state forward
        return losses.at[i].set(losses[i - 1]), svi_state_best
    losses, svi_state_best = jax.lax.cond(loss < losses[i - 1],
                                          update_fn, keep_fn, None)
    return (svi_state, svi_state_best, losses)


svi_state = svi.init(jax.random.PRNGKey(42), cl_obs)

num_steps = 5_000
losses = jnp.zeros(num_steps)   # must be a JAX array for .at[...].set(...)
losses = losses.at[0].set(1e10)
svi_state_best = svi_state
carry = (svi_state, svi_state_best, losses)
carry = jax.lax.fori_loop(1, num_steps, body_fn, carry)

print("Save loss")
np.save('loss_shear_nc_DEBUG.npy', carry[2])  # stepwise decrease of the loss

Now, using NeuTraReparam, I perform MCMC/NUTS:

from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer import MCMC, NUTS

neutra = NeuTraReparam(guide, svi.get_params(carry[1]))  # best SVI state
neutra_model = neutra.reparam(model_spl)

nuts_kernel = NUTS(neutra_model,
                   init_strategy=numpyro.infer.init_to_median(),
                   dense_mass=True,
                   max_tree_depth=5)

mcmc_neutra = MCMC(nuts_kernel, num_warmup=1_000,
                   num_samples=8_000,
                   num_chains=1, progress_bar=False)

print('NUTS neutra')
mcmc_neutra.run(jax.random.PRNGKey(42), cl_obs)

mcmc_neutra.print_summary()
zs = mcmc_neutra.get_samples()["auto_shared_latent"]
samples_nuts_neutra = neutra.transform_sample(zs)

print("Save Neutra samples")
np.save('neutra_samples_shear_nc_dense_DEBUG.npy',samples_nuts_neutra)

Thanks for your attention.

Well, the GPU is still barely used and one CPU thread runs at 100%… So I wonder if I have missed something in the NumPyro NUTS/MCMC settings?

You can check the active devices using jax.local_devices(). If you use GPU, I guess the CPU is only used to transfer data from CPU to GPU (or maybe some control-flow code in XLA that runs on CPU, I'm not sure). You might use jax.device_put to move all data to the GPU. I don't think numpyro has any specific setting for this.
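Concretely, something along these lines (the device names printed will differ depending on your machine and jaxlib build):

```python
import jax
import jax.numpy as jnp

# which devices does this process see, and which backend is the default?
print(jax.local_devices())   # e.g. a GpuDevice entry on a working CUDA build
print(jax.default_backend()) # 'gpu' if jaxlib+CUDA is picked up, else 'cpu'

# move large, constant data (e.g. the observation matrix) onto the
# default device once, up front, instead of on every call
x = jax.device_put(jnp.ones((2500, 2500)), jax.local_devices()[0])
print(x.shape)
```

If `jax.default_backend()` reports 'cpu' despite the GPU being visible, that usually points to a jax/jaxlib/CUDA version mismatch rather than a NumPyro setting.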

Well, why are there 40 CPU threads at 0% and a single one at 100%? And the data are very small: the observations are a 2500x2500 matrix…

why there are 40 CPUs threads at 0%

Are you using CPU or GPU? If you are using GPU then this is expected.

a single at 100%

Probably this thread handles communication with the GPU, e.g. printing out the progress-bar info. I don't know why it takes 100%, but to me it is normal for a jax program.

Here is my snippet, which is launched on a GPU (V100) after SVI has converged. It is the NUTS step that I would like to speed up.

from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer import MCMC, NUTS

neutra = NeuTraReparam(guide, svi.get_params(carry[1]))  # svi_result.params
neutra_model = neutra.reparam(model_spl)

nuts_kernel = NUTS(neutra_model,
                   init_strategy=numpyro.infer.init_to_median(),
                   dense_mass=True,
                   max_tree_depth=5)
mcmc_neutra = MCMC(nuts_kernel, num_warmup=1_000,
                   num_samples=8_000,
                   num_chains=1, progress_bar=False)

mcmc_neutra.run(jax.random.PRNGKey(42), cl_obs)

Probably with neutra, it is enough to use the init_to_uniform or init_to_feasible strategy. Otherwise, your code looks great.

Well, even if it is very cool that my code looks great, in practice the job had not finished after 9 days of running, while I have only 22 priors and a simple likelihood with a single minimum…

Look at the CPU and GPU usage during the NUTS run of the NeuTra-reparam model:

[Screenshot: CPU/GPU utilization during the NeuTra-reparam NUTS run]