from jax import config
config.update("jax_enable_x64", True)
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import celerite2.jax
from celerite2.jax import terms as jax_terms
import numpy as np
np.random.seed(42)
t = np.sort(
np.append(
np.random.uniform(0, 3.8, 57),
np.random.uniform(5.5, 10, 68),
)
) # The input coordinates must be sorted
yerr = np.random.uniform(0.08, 0.22, len(t))
y = (
0.2 * (t - 5)
+ np.sin(3 * t + 0.1 * (t - 5) ** 2)
+ yerr * np.random.randn(len(t))
)
true_t = np.linspace(0, 10, 500)
true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)
prior_sigma = 2.0
def numpyro_model(t, yerr, y=None):
mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))
log_sigma1 = numpyro.sample("log_sigma1", dist.Normal(0.0, prior_sigma))
log_rho1 = numpyro.sample("log_rho1", dist.Normal(0.0, prior_sigma))
log_tau = numpyro.sample("log_tau", dist.Normal(0.0, prior_sigma))
term1 = jax_terms.SHOTerm(
sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
)
log_sigma2 = numpyro.sample("log_sigma2", dist.Normal(0.0, prior_sigma))
log_rho2 = numpyro.sample("log_rho2", dist.Normal(0.0, prior_sigma))
term2 = jax_terms.SHOTerm(
sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25
)
kernel = term1 + term2
gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
numpyro.sample("obs", gp.numpyro_dist(), obs=y)
numpyro.deterministic("psd", kernel.get_psd(omega))
nuts_kernel = NUTS(numpyro_model, dense_mass=True)
mcmc = MCMC(
nuts_kernel,
num_warmup=1000,
num_samples=1000,
num_chains=2,
progress_bar=False,
)
rng_key = random.PRNGKey(34923)
mcmc.run(rng_key, t, yerr, y=y)
Running the above gives me the following error:
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
import matplotlib.pyplot as plt
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
_check_versions()
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
module = importlib.import_module(modname)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$ python test.py
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 16, in <module>
import matplotlib.pyplot as plt
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 264, in <module>
_check_versions()
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/matplotlib/__init__.py", line 258, in _check_versions
module = importlib.import_module(modname)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/ixsoftware/python/3.12.6/install/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'dateutil'
(.venv_cuda) [username@node10 temp]$
(.venv_cuda) [username@node10 temp]$ python test.py
/computefs/scratch/username/mypackage/notebooks/temp/test.py:64: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc = MCMC(
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
mcmc.run(rng_key, t, yerr, y=y)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
states, last_state = _laxmap(partial_map_fn, map_args)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
ys.append(f(x))
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
self(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
self._do_compute(quiet)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
self._d, self._W = ops.factor(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
d, W, S = factor_p.bind(t, c, a, U, V)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 72, in <module>
mcmc.run(rng_key, t, yerr, y=y)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 706, in run
states, last_state = _laxmap(partial_map_fn, map_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 177, in _laxmap
ys.append(f(x))
^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 688, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
self(*args, **kwargs)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/test.py", line 57, in numpyro_model
gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/core.py", line 317, in compute
self._do_compute(quiet)
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/celerite2.py", line 34, in _do_compute
self._d, self._W = ops.factor(
^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/celerite2/jax/ops.py", line 39, in factor
d, W, S = factor_p.bind(t, c, a, U, V)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 438, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 356, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 189, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 2781, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/core.py", line 948, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1764, in _pjit_call_impl
return xc._xla.pjit(
^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1739, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1661, in _pjit_call_impl_python
compiled = _resolve_and_lower(
^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1628, in _resolve_and_lower
lowered = _pjit_lower(
^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_lower
return _pjit_lower_cached(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/pjit.py", line 1801, in _pjit_lower_cached
return pxla.lower_sharding_computation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2232, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1952, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1152, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1610, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1825, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/computefs/scratch/username/mypackage/notebooks/temp/.venv_cuda/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1914, in lower_per_platform
raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'celerite2_factor' not found for platform cuda
celerite2==0.3.2
contourpy==1.3.1
cycler==0.12.1
fonttools==4.56.0
jax==0.4.34
jax-cuda12-pjrt==0.4.34
jax-cuda12-plugin==0.4.34
jaxlib==0.4.34
jaxopt==0.8.3
kiwisolver==1.4.8
matplotlib==3.10.0
ml_dtypes==0.5.1
multipledispatch==1.0.0
numpy==2.2.3
numpyro==0.17.0
nvidia-cublas-cu12==12.8.3.14
nvidia-cuda-cupti-cu12==12.8.57
nvidia-cuda-nvcc-cu12==12.8.61
nvidia-cuda-runtime-cu12==12.8.57
nvidia-cudnn-cu12==9.7.1.26
nvidia-cufft-cu12==11.3.3.41
nvidia-cusolver-cu12==11.7.2.55
nvidia-cusparse-cu12==12.5.7.53
nvidia-nccl-cu12==2.25.1
nvidia-nvjitlink-cu12==12.8.61
opt_einsum==3.4.0
pillow==11.1.0
pyparsing==3.2.1
scipy==1.15.2
setuptools==75.8.0
tqdm==4.67.1
I’m guessing this is something to do with how celerite2 and JAX work together - thought I would check here if anyone has any suggestions on a fix?