Too much CPU utilization

Hi,

It is my understanding that if a NumPyro model starts running, it should use a GPU most of the time. However, contrary to my knowledge, it is using a CPU beyond expectation (can be seen in the screenshot). This unexpected behaviour goes throughout the run.

I am using the standard way to initialize the guide. The data is transferred to the device using jax.device_put. All heavy computations inside the guide are inside a jax.jit transform.

Any hints on what can be done to make it faster?

Thank you

Guide

# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0


from collections.abc import Callable
from typing import Dict, List, Optional

import equinox as eqx
import jax
import numpyro
from jax import Array, numpy as jnp
from jaxtyping import ArrayLike
from numpyro._typing import DistributionT
from numpyro.distributions import Distribution

from ..models.utils import JointDistribution, LazyJointDistribution, ScaledMixture


__all__ = ["numpyro_poisson_likelihood"]


def numpyro_poisson_likelihood(
    dist_fn: Callable[..., DistributionT],
    priors: JointDistribution,
    variables: Dict[str, DistributionT],
    variables_index: Dict[str, int],
    log_constants: ArrayLike,
    poisson_mean_estimator: Callable[[ScaledMixture], Array],
    where_fns: Optional[List[Callable[..., Array]]],
    constants: Dict[str, Array],
) -> Callable[[List[Array], List[Array], List[Array]], Array]:
    is_lazy_prior = isinstance(priors, LazyJointDistribution)
    if is_lazy_prior:
        dependencies = priors.dependencies
        partial_order = priors.partial_order
    del priors

    def log_likelihood_fn(*args: Array):
        if is_lazy_prior:
            partial_variables_samples = [
                numpyro.sample(parameter_name, prior_dist)
                if isinstance(prior_dist, Distribution)
                else (parameter_name, prior_dist)
                for parameter_name, prior_dist in sorted(
                    variables.items(), key=lambda x: x[0]
                )
            ]
            for i in partial_order:
                kwargs = {
                    k: partial_variables_samples[v] for k, v in dependencies[i].items()
                }
                parameter_name, prior_dist_fn = partial_variables_samples[i]
                if isinstance(prior_dist_fn, jax.tree_util.Partial):
                    prior_dist = prior_dist_fn.func(
                        *prior_dist_fn.args, **prior_dist_fn.keywords, **kwargs
                    )  # type: ignore
                partial_variables_samples[i] = numpyro.sample(
                    parameter_name, prior_dist
                )
            variables_samples = partial_variables_samples  # type: ignore
        else:
            variables_samples = [
                numpyro.sample(parameter_name, prior_dist)
                for parameter_name, prior_dist in sorted(
                    variables.items(), key=lambda x: x[0]
                )
            ]
        mapped_params = {
            name: variables_samples[i] for name, i in variables_index.items()
        }

        model_instance: DistributionT = dist_fn(**mapped_params, validate_args=True)

        # μ = E_{θ|Λ}[VT(θ)]
        expected_rates = eqx.filter_jit(poisson_mean_estimator)(model_instance)

        n_buckets = len(args) // 3

        @jax.jit
        def _total_log_likelihood_fn(*args: Array):
            data_group = args[0:n_buckets]
            log_ref_priors_group = args[n_buckets : 2 * n_buckets]
            masks_group = args[2 * n_buckets : 3 * n_buckets]

            total_log_likelihood = log_constants  # - Σ log(M_i)
            # Σ log Σ exp (log p(θ|data_n) - log π_n)
            for batched_data, batched_log_ref_priors, batched_masks in zip(
                data_group, log_ref_priors_group, masks_group
            ):
                safe_data = jnp.where(
                    jnp.expand_dims(batched_masks, axis=-1),
                    batched_data,
                    model_instance.support.feasible_like(batched_data),
                )
                safe_log_ref_prior = jnp.where(
                    batched_masks, batched_log_ref_priors, 0.0
                )

                n_events_per_bucket, n_samples, _ = batched_data.shape
                batched_model_log_prob = jax.vmap(
                    jax.vmap(model_instance.log_prob, axis_size=n_samples),
                    axis_size=n_events_per_bucket,
                )(safe_data)  # type: ignore
                safe_model_log_prob = jnp.where(
                    batched_masks, batched_model_log_prob, -jnp.inf
                )
                batched_log_prob: Array = safe_model_log_prob - safe_log_ref_prior
                batched_log_prob = jnp.where(
                    batched_masks & (~jnp.isnan(batched_log_prob)),
                    batched_log_prob,
                    -jnp.inf,
                )
                log_prob_sum = jax.nn.logsumexp(
                    batched_log_prob,
                    axis=-1,
                    where=~jnp.isneginf(batched_log_prob),
                )
                safe_log_prob_sum = jnp.where(
                    jnp.isneginf(log_prob_sum), -jnp.inf, log_prob_sum
                )
                total_log_likelihood += jnp.sum(safe_log_prob_sum, axis=-1)

            if where_fns is not None and len(where_fns) > 0:
                mask = where_fns[0](**constants, **mapped_params)
                for where_fn in where_fns[1:]:
                    mask = mask & where_fn(**constants, **mapped_params)
                total_log_likelihood = jnp.where(
                    mask,
                    total_log_likelihood,
                    -jnp.inf,  # type: ignore
                )
            return total_log_likelihood

        # - μ + Σ log Σ exp (log p(θ|data_n) - log π_n) - Σ log(M_i)
        numpyro.factor(
            "log_likelihood",
            _total_log_likelihood_fn(*args) - expected_rates,
        )

    return log_likelihood_fn  # type: ignore

Screenshot

Enviornment Info

>>> python -c "import jax; jax.print_environment_info()"

jax:    0.8.0
jaxlib: 0.8.0
numpy:  2.3.4
python: 3.13.9 | packaged by conda-forge | (main, Oct 22 2025, 23:33:35) [GCC 14.3.0]
device info: NVIDIA A100-SXM4-80GB-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ldas-pcdev12', release='4.18.0-553.77.1.el8_10.x86_64', version='#1 SMP Fri Oct 3 14:30:23 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_PREALLOCATE=false
JAX_COMPILATION_CACHE_DIR=/home/muhammad.zeeshan/jax_cache
XLA_FLAGS=--xla_cpu_multi_thread_eigen=false

$ nvidia-smi
Tue Nov  4 13:12:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:01:00.0 Off |                    0 |
| N/A   23C    P0             67W /  500W |    1705MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          Off |   00000000:41:00.0 Off |                    0 |
| N/A   44C    P0            196W /  500W |    2067MiB /  81920MiB |     64%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          Off |   00000000:81:00.0 Off |                    0 |
| N/A   55C    P0            204W /  500W |    2321MiB /  81920MiB |     67%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-80GB          Off |   00000000:C1:00.0 Off |                    0 |
| N/A   19C    P0             55W /  500W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A           17683      C   /opt/env/bin/python3                   1694MiB |
|    1   N/A  N/A          111112      C   python                                 2056MiB |
|    2   N/A  N/A          118927      C   python                                 1880MiB |
|    2   N/A  N/A          179584      C   python                                  424MiB |
+-----------------------------------------------------------------------------------------+
>>> uv pip list

Using Python 3.13.9 environment at: /home/muhammad.zeeshan/.conda/envs/gwkenv
Package             Version
------------------- -------------------
absl-py             2.3.1
arviz               0.22.0
astroplan           0.10.1
astropy             7.1.1
astropy-healpix     1.1.2
astropy-iers-data   0.2025.11.3.0.38.37
attrs               25.4.0
bilby               2.7.0
bilby-cython        0.5.3
certifi             2025.10.5
cffi                2.0.0
cfgv                3.4.0
charset-normalizer  3.4.4
chex                0.1.91
click               8.3.0
cloudpickle         3.1.2
colorspacious       1.1.2
contourpy           1.3.3
corner              2.2.3
crc32c              2.8
cryptography        46.0.3
cycler              0.12.1
dask                2025.10.0
dateparser          1.2.2
dill                0.4.0
distlib             0.4.0
donfig              0.8.1.post1
dqsegdb2            1.3.0
dynesty             3.0.0
emcee               3.1.6
equinox             0.13.2
filelock            3.20.0
flowmc              0.4.5
fonttools           4.60.1
fsspec              2025.10.0
glasbey             0.3.0
gwdatafind          2.1.1
gwkokab             0.2.0
gwosc               0.8.1
gwpy                3.0.13
h5netcdf            1.7.3
h5py                3.15.1
healpy              1.18.1
htcondor            25.3.1
identify            2.6.15
idna                3.11
igwn-auth-utils     1.4.0
igwn-ligolw         2.1.0
igwn-segments       2.1.0
jax                 0.8.0
jax-cuda13-pjrt     0.8.0
jax-cuda13-plugin   0.8.0
jaxlib              0.8.0
jaxtyping           0.3.3
jenkspy             0.4.1
joblib              1.5.2
kiwisolver          1.4.9
lalsuite            7.26.1
ligo-gracedb        2.14.3
ligo-skymap         2.4.1
ligotimegps         2.1.0
llvmlite            0.45.1
locket              1.0.0
loguru              0.7.3
lscsoft-glue        4.1.1
matplotlib          3.10.7
ml-dtypes           0.5.3
mplcursors          0.7
multipledispatch    1.0.0
natsort             8.4.0
networkx            3.5
nodeenv             1.9.1
numba               0.62.1
numcodecs           0.16.3
numpy               2.3.4
numpyro             0.19.0
nvidia-cublas       13.1.0.3
nvidia-cuda-crt     13.0.88
nvidia-cuda-cupti   13.0.85
nvidia-cuda-nvcc    13.0.88
nvidia-cuda-nvrtc   13.0.88
nvidia-cuda-runtime 13.0.96
nvidia-cudnn-cu13   9.14.0.64
nvidia-cufft        12.0.0.61
nvidia-cusolver     12.0.4.66
nvidia-cusparse     12.6.3.3
nvidia-ml-py        13.580.82
nvidia-nccl-cu13    2.28.7
nvidia-nvjitlink    13.0.88
nvidia-nvshmem-cu13 3.4.5
nvidia-nvvm         13.0.88
nvitop              1.5.3
opt-einsum          3.4.0
optax               0.2.6
packaging           25.0
pandas              2.3.3
partd               1.4.2
pillow              12.0.0
pip                 25.2
platformdirs        4.5.0
pre-commit          4.3.0
precession          2.1.1
psutil              7.1.3
ptemcee             1.0.0
pyavm               0.9.7
pycparser           2.23
pyerfa              2.0.1.5
pyjwt               2.10.1
pyparsing           3.2.5
python-dateutil     2.9.0.post0
pytz                2025.2
pyyaml              6.0.3
quadax              0.2.11
regex               2025.11.3
reproject           0.18.0
requests            2.32.5
rift                0.0.17.6
safe-netrc          1.0.1
scikit-learn        1.7.2
scipy               1.15.3
scitokens           1.8.1
seaborn             0.13.2
setuptools          80.9.0
shapely             2.1.2
six                 1.17.0
threadpoolctl       3.6.0
toolz               1.1.0
tqdm                4.67.1
typing-extensions   4.15.0
tzdata              2025.2
tzlocal             5.3.1
urllib3             2.5.0
uv                  0.9.7
virtualenv          20.35.4
wadler-lindig       0.1.7
xarray              2025.10.1
xarray-einstats     0.9.1
zarr                3.1.3