SVI + NeuTra/NUTS: 1 of ~40 CPU threads at 100% while the others stay at 0%, and only ~60% GPU usage


I am running on CPU (Linux) + V100 GPU hardware to perform SVI, then a NeuTra reparameterization, then MCMC with NUTS.

What puzzles me is that:

  1. SVI spawns many threads (~40), none of which uses the CPU, except for one thread at 100% CPU; the GPU is barely used.
  2. For NeuTra it is the same in terms of CPU and threads, while the GPU is used at around 60%.

I wonder whether I have missed something needed to put the CPU threads to work and to use the GPU more efficiently.

Here is some information on my packages, and some snippets.

My NumPyro model looks schematically like this:

```python
def model_spl(obs=None):
    # define 16 priors
    # compute y = f(all_prior_variables)
    return numpyro.sample('y', dist.MultivariateNormal(y, ...))
```

where the P and C matrices, of size 2500×2500, are defined beforehand and are constant in my model.

I then fit an approximation of the posterior via SVI with the following snippet (PHASE 1):

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.infer import SVI, Trace_ELBO, autoguide

guide = autoguide.AutoMultivariateNormal(model_spl, ...)

optimizer = numpyro.optim.Adam(5e-3)

svi = SVI(model_spl, guide, optimizer, loss=Trace_ELBO())

def body_fn(i, carry):
    svi_state, svi_state_best, losses = carry
    svi_state, loss = svi.update(svi_state, cl_obs)

    # losses[i] keeps the lowest loss seen so far, svi_state_best the matching state
    def update_fn(_):
        return losses.at[i].set(loss), svi_state

    def keep_fn(_):
        return losses.at[i].set(losses[i - 1]), svi_state_best

    losses, svi_state_best = jax.lax.cond(loss < losses[i - 1], update_fn, keep_fn, None)
    return (svi_state, svi_state_best, losses)

svi_state = svi.init(jax.random.PRNGKey(42), cl_obs)

losses = jnp.zeros(num_steps)
losses = losses.at[0].set(1e10)
svi_state_best = svi_state
carry = (svi_state, svi_state_best, losses)
carry = jax.lax.fori_loop(1, num_steps, body_fn, carry)

print("Save loss")
np.save('loss_shear_nc_DEBUG.npy', carry[2])  # step decrease of the loss
```

Then, using NeuTraReparam, I run MCMC with NUTS:

```python
import numpy as np
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer import MCMC, NUTS, init_to_sample

neutra = NeuTraReparam(guide, svi.get_params(carry[1]))
neutra_model = neutra.reparam(model_spl)

nuts_kernel = NUTS(neutra_model, ...)

mcmc_neutra = MCMC(nuts_kernel, num_warmup=1_000)  # remaining arguments elided

print('NUTS neutra')
mcmc_neutra.run(rng_key, cl_obs)  # rng_key: a jax.random.PRNGKey (seed not given)

zs = mcmc_neutra.get_samples()["auto_shared_latent"]
samples_nuts_neutra = neutra.transform_sample(zs)

print("Save Neutra samples")
np.save('neutra_samples_shear_nc_dense_DEBUG.npy', samples_nuts_neutra)
```

Thanks for your attention.

The GPU is still barely used, and one CPU thread runs at 100%… So I wonder whether I have missed some setting in NumPyro's NUTS/MCMC.

You can check the active devices with jax.local_devices(). If you use the GPU, I guess the CPU is only used to transfer data to the GPU (or maybe to run some XLA control-flow code on the host; I'm not sure). You might use jax.device_put to move all data to the GPU. I don't think NumPyro has any specific setting for this.
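Here is a minimal sketch of those two suggestions: first list the devices, then move the constant data to the GPU once. The arrays are hypothetical stand-ins for the 2500×2500 matrices and cl_obs. On a CPU-only install the same code simply runs on the CPU.

```python
import jax
import numpy as np

# Devices visible to this process; a working GPU install should report a GPU
# backend here, otherwise JAX has fallen back to the CPU
devices = jax.local_devices()
print(jax.default_backend(), devices)

# Hypothetical host (NumPy) arrays standing in for the constant matrices and cl_obs
N = 2500
C_host = np.eye(N, dtype=np.float32)
cl_obs_host = np.zeros(N, dtype=np.float32)

# Copy them once to the first device, so that later jitted calls receive
# device arrays instead of host arrays that must be transferred on each call
C = jax.device_put(C_host, devices[0])
cl_obs = jax.device_put(cl_obs_host, devices[0])
print(C.devices(), cl_obs.devices())
```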

But why are there 40 CPU threads at 0% and a single one at 100%? And the observed data are small: a 2500×2500 matrix…

why there are 40 CPUs threads at 0%

Are you using CPU or GPU? If you are using GPU, this is expected.

a single at 100%

This thread probably handles communication with the GPU, e.g. to print the progress-bar information. I don't know why it takes 100%, but to me this is normal for a JAX program.

Here is my snippet, launched on a GPU (V100) after SVI has converged. It is the NUTS step that I would like to speed up.

```python
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer import MCMC, NUTS, init_to_sample

neutra = NeuTraReparam(guide, svi.get_params(carry[1]))  # svi_result.params
neutra_model = neutra.reparam(model_spl)

nuts_kernel = NUTS(neutra_model, ...)
mcmc_neutra = MCMC(nuts_kernel, num_warmup=1_000,
                   num_chains=1, progress_bar=False)  # remaining arguments elided
mcmc_neutra.run(rng_key, cl_obs)  # rng_key: a jax.random.PRNGKey (seed not given)
```

With NeuTra, it is probably enough to use the init_to_uniform or init_to_feasible init strategy. Otherwise, your code looks great.

It is nice that my code looks great, but in practice the job has not finished after 9 days of running, even though I have only 22 priors and a simple likelihood with a single minimum…

Here is the CPU and GPU usage during the NUTS run on the NeuTra-reparameterized model:
