Hi,
It is my understanding that once a NumPyro model is running, it should spend most of its time on the GPU. Contrary to that expectation, my run spends far more time on the CPU than I would expect (see the screenshot below), and this behaviour persists throughout the run.
I initialize the guide in the standard way, move the data to the device with jax.device_put, and wrap all the heavy computation inside the guide in a jax.jit transform.
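For reference, this is roughly how I check that JAX actually sees the GPU and that arrays land on it (a minimal sketch with a dummy array; x stands in for my real data):

import jax
from jax import numpy as jnp

print(jax.default_backend())  # expect "gpu" when the CUDA plugin is active
print(jax.devices())          # expect [CudaDevice(id=0), ...]

# Place a dummy array on the first GPU and confirm where it lives.
x = jax.device_put(jnp.ones((4, 4)), jax.devices("gpu")[0])
print(x.devices())            # expect {CudaDevice(id=0)}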
Any hints on what can be done to make it faster?
Thank you
Guide
# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Dict, List, Optional

import equinox as eqx
import jax
import numpyro
from jax import Array, numpy as jnp
from jaxtyping import ArrayLike
from numpyro._typing import DistributionT
from numpyro.distributions import Distribution

from ..models.utils import JointDistribution, LazyJointDistribution, ScaledMixture

__all__ = ["numpyro_poisson_likelihood"]
def numpyro_poisson_likelihood(
    dist_fn: Callable[..., DistributionT],
    priors: JointDistribution,
    variables: Dict[str, DistributionT],
    variables_index: Dict[str, int],
    log_constants: ArrayLike,
    poisson_mean_estimator: Callable[[ScaledMixture], Array],
    where_fns: Optional[List[Callable[..., Array]]],
    constants: Dict[str, Array],
) -> Callable[[List[Array], List[Array], List[Array]], Array]:
    is_lazy_prior = isinstance(priors, LazyJointDistribution)
    if is_lazy_prior:
        dependencies = priors.dependencies
        partial_order = priors.partial_order
    del priors
    def log_likelihood_fn(*args: Array):
        if is_lazy_prior:
            partial_variables_samples = [
                numpyro.sample(parameter_name, prior_dist)
                if isinstance(prior_dist, Distribution)
                else (parameter_name, prior_dist)
                for parameter_name, prior_dist in sorted(
                    variables.items(), key=lambda x: x[0]
                )
            ]
            for i in partial_order:
                kwargs = {
                    k: partial_variables_samples[v] for k, v in dependencies[i].items()
                }
                parameter_name, prior_dist_fn = partial_variables_samples[i]
                if isinstance(prior_dist_fn, jax.tree_util.Partial):
                    prior_dist = prior_dist_fn.func(
                        *prior_dist_fn.args, **prior_dist_fn.keywords, **kwargs
                    )  # type: ignore
                    partial_variables_samples[i] = numpyro.sample(
                        parameter_name, prior_dist
                    )
            variables_samples = partial_variables_samples  # type: ignore
        else:
            variables_samples = [
                numpyro.sample(parameter_name, prior_dist)
                for parameter_name, prior_dist in sorted(
                    variables.items(), key=lambda x: x[0]
                )
            ]
        mapped_params = {
            name: variables_samples[i] for name, i in variables_index.items()
        }

        model_instance: DistributionT = dist_fn(**mapped_params, validate_args=True)

        # μ = E_{θ|Λ}[VT(θ)]
        expected_rates = eqx.filter_jit(poisson_mean_estimator)(model_instance)

        # args is a flat tuple of n_buckets data arrays, then n_buckets
        # log-reference-prior arrays, then n_buckets mask arrays
        n_buckets = len(args) // 3

        @jax.jit
        def _total_log_likelihood_fn(*args: Array):
            data_group = args[0:n_buckets]
            log_ref_priors_group = args[n_buckets : 2 * n_buckets]
            masks_group = args[2 * n_buckets : 3 * n_buckets]
            total_log_likelihood = log_constants  # - Σ log(M_i)

            # Σ log Σ exp (log p(θ|data_n) - log π_n)
            for batched_data, batched_log_ref_priors, batched_masks in zip(
                data_group, log_ref_priors_group, masks_group
            ):
                safe_data = jnp.where(
                    jnp.expand_dims(batched_masks, axis=-1),
                    batched_data,
                    model_instance.support.feasible_like(batched_data),
                )
                safe_log_ref_prior = jnp.where(
                    batched_masks, batched_log_ref_priors, 0.0
                )
                n_events_per_bucket, n_samples, _ = batched_data.shape
                batched_model_log_prob = jax.vmap(
                    jax.vmap(model_instance.log_prob, axis_size=n_samples),
                    axis_size=n_events_per_bucket,
                )(safe_data)  # type: ignore
                safe_model_log_prob = jnp.where(
                    batched_masks, batched_model_log_prob, -jnp.inf
                )
                batched_log_prob: Array = safe_model_log_prob - safe_log_ref_prior
                batched_log_prob = jnp.where(
                    batched_masks & (~jnp.isnan(batched_log_prob)),
                    batched_log_prob,
                    -jnp.inf,
                )
                log_prob_sum = jax.nn.logsumexp(
                    batched_log_prob,
                    axis=-1,
                    where=~jnp.isneginf(batched_log_prob),
                )
                safe_log_prob_sum = jnp.where(
                    jnp.isneginf(log_prob_sum), -jnp.inf, log_prob_sum
                )
                total_log_likelihood += jnp.sum(safe_log_prob_sum, axis=-1)
            if where_fns is not None and len(where_fns) > 0:
                mask = where_fns[0](**constants, **mapped_params)
                for where_fn in where_fns[1:]:
                    mask = mask & where_fn(**constants, **mapped_params)
                total_log_likelihood = jnp.where(
                    mask,
                    total_log_likelihood,
                    -jnp.inf,  # type: ignore
                )
            return total_log_likelihood

        # - μ + Σ log Σ exp (log p(θ|data_n) - log π_n) - Σ log(M_i)
        numpyro.factor(
            "log_likelihood",
            _total_log_likelihood_fn(*args) - expected_rates,
        )

    return log_likelihood_fn  # type: ignore
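In case it helps, this is how I capture a profile to see which ops run on the host versus the device (a minimal sketch; run_step is a placeholder for one step of my actual training loop):

import jax
import jax.profiler
from jax import numpy as jnp

def run_step():
    # Placeholder workload standing in for one SVI update.
    return jnp.ones((1024, 1024)) @ jnp.ones((1024, 1024))

# The resulting trace can be inspected with TensorBoard's profiler plugin.
with jax.profiler.trace("/tmp/jax-trace"):
    out = run_step()
    jax.block_until_ready(out)  # make sure async dispatch finishes inside the trace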
Screenshot
Environment Info
$ python -c "import jax; jax.print_environment_info()"
jax: 0.8.0
jaxlib: 0.8.0
numpy: 2.3.4
python: 3.13.9 | packaged by conda-forge | (main, Oct 22 2025, 23:33:35) [GCC 14.3.0]
device info: NVIDIA A100-SXM4-80GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ldas-pcdev12', release='4.18.0-553.77.1.el8_10.x86_64', version='#1 SMP Fri Oct 3 14:30:23 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_PREALLOCATE=false
JAX_COMPILATION_CACHE_DIR=/home/muhammad.zeeshan/jax_cache
XLA_FLAGS=--xla_cpu_multi_thread_eigen=false
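To rule out a silent CPU fallback, I can also pin the platform so JAX fails loudly if the CUDA backend is unavailable (a minimal sketch, equivalent to exporting JAX_PLATFORMS=cuda):

import jax

# Must run before any computation initializes a backend; restricts JAX to
# CUDA so it errors out instead of silently falling back to CPU.
jax.config.update("jax_platforms", "cuda")
print(jax.devices())  # raises if no CUDA device is usable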
$ nvidia-smi
Tue Nov 4 13:12:50 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-SXM4-80GB Off | 00000000:01:00.0 Off | 0 |
| N/A 23C P0 67W / 500W | 1705MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100-SXM4-80GB Off | 00000000:41:00.0 Off | 0 |
| N/A 44C P0 196W / 500W | 2067MiB / 81920MiB | 64% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA A100-SXM4-80GB Off | 00000000:81:00.0 Off | 0 |
| N/A 55C P0 204W / 500W | 2321MiB / 81920MiB | 67% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA A100-SXM4-80GB Off | 00000000:C1:00.0 Off | 0 |
| N/A 19C P0 55W / 500W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 17683 C /opt/env/bin/python3 1694MiB |
| 1 N/A N/A 111112 C python 2056MiB |
| 2 N/A N/A 118927 C python 1880MiB |
| 2 N/A N/A 179584 C python 424MiB |
+-----------------------------------------------------------------------------------------+
$ uv pip list
Using Python 3.13.9 environment at: /home/muhammad.zeeshan/.conda/envs/gwkenv
Package Version
------------------- -------------------
absl-py 2.3.1
arviz 0.22.0
astroplan 0.10.1
astropy 7.1.1
astropy-healpix 1.1.2
astropy-iers-data 0.2025.11.3.0.38.37
attrs 25.4.0
bilby 2.7.0
bilby-cython 0.5.3
certifi 2025.10.5
cffi 2.0.0
cfgv 3.4.0
charset-normalizer 3.4.4
chex 0.1.91
click 8.3.0
cloudpickle 3.1.2
colorspacious 1.1.2
contourpy 1.3.3
corner 2.2.3
crc32c 2.8
cryptography 46.0.3
cycler 0.12.1
dask 2025.10.0
dateparser 1.2.2
dill 0.4.0
distlib 0.4.0
donfig 0.8.1.post1
dqsegdb2 1.3.0
dynesty 3.0.0
emcee 3.1.6
equinox 0.13.2
filelock 3.20.0
flowmc 0.4.5
fonttools 4.60.1
fsspec 2025.10.0
glasbey 0.3.0
gwdatafind 2.1.1
gwkokab 0.2.0
gwosc 0.8.1
gwpy 3.0.13
h5netcdf 1.7.3
h5py 3.15.1
healpy 1.18.1
htcondor 25.3.1
identify 2.6.15
idna 3.11
igwn-auth-utils 1.4.0
igwn-ligolw 2.1.0
igwn-segments 2.1.0
jax 0.8.0
jax-cuda13-pjrt 0.8.0
jax-cuda13-plugin 0.8.0
jaxlib 0.8.0
jaxtyping 0.3.3
jenkspy 0.4.1
joblib 1.5.2
kiwisolver 1.4.9
lalsuite 7.26.1
ligo-gracedb 2.14.3
ligo-skymap 2.4.1
ligotimegps 2.1.0
llvmlite 0.45.1
locket 1.0.0
loguru 0.7.3
lscsoft-glue 4.1.1
matplotlib 3.10.7
ml-dtypes 0.5.3
mplcursors 0.7
multipledispatch 1.0.0
natsort 8.4.0
networkx 3.5
nodeenv 1.9.1
numba 0.62.1
numcodecs 0.16.3
numpy 2.3.4
numpyro 0.19.0
nvidia-cublas 13.1.0.3
nvidia-cuda-crt 13.0.88
nvidia-cuda-cupti 13.0.85
nvidia-cuda-nvcc 13.0.88
nvidia-cuda-nvrtc 13.0.88
nvidia-cuda-runtime 13.0.96
nvidia-cudnn-cu13 9.14.0.64
nvidia-cufft 12.0.0.61
nvidia-cusolver 12.0.4.66
nvidia-cusparse 12.6.3.3
nvidia-ml-py 13.580.82
nvidia-nccl-cu13 2.28.7
nvidia-nvjitlink 13.0.88
nvidia-nvshmem-cu13 3.4.5
nvidia-nvvm 13.0.88
nvitop 1.5.3
opt-einsum 3.4.0
optax 0.2.6
packaging 25.0
pandas 2.3.3
partd 1.4.2
pillow 12.0.0
pip 25.2
platformdirs 4.5.0
pre-commit 4.3.0
precession 2.1.1
psutil 7.1.3
ptemcee 1.0.0
pyavm 0.9.7
pycparser 2.23
pyerfa 2.0.1.5
pyjwt 2.10.1
pyparsing 3.2.5
python-dateutil 2.9.0.post0
pytz 2025.2
pyyaml 6.0.3
quadax 0.2.11
regex 2025.11.3
reproject 0.18.0
requests 2.32.5
rift 0.0.17.6
safe-netrc 1.0.1
scikit-learn 1.7.2
scipy 1.15.3
scitokens 1.8.1
seaborn 0.13.2
setuptools 80.9.0
shapely 2.1.2
six 1.17.0
threadpoolctl 3.6.0
toolz 1.1.0
tqdm 4.67.1
typing-extensions 4.15.0
tzdata 2025.2
tzlocal 5.3.1
urllib3 2.5.0
uv 0.9.7
virtualenv 20.35.4
wadler-lindig 0.1.7
xarray 2025.10.1
xarray-einstats 0.9.1
zarr 3.1.3
