I have a high-dimensional multi-level regression model that I fit with MCMC. It originally ran on Python 3.10; along with the move to Python 3.12, NumPyro went from 0.10.1 to 0.18.0 and JAX from 0.3 to 0.5.0. The data size is 10k-30k rows. The average processing time used to be around 10 minutes. After upgrading to Python 3.12, processing takes 3-4 hours or more. The code hasn’t changed at all apart from the package upgrades. I use MCMC with the NUTS sampler and the same config. MCMC compilation takes an especially long time. Over 255 steps it previously ran at 1.44it/s; now it reports 1ms/s. Does anyone know why this happens? Is there anything I can do to improve the speed? SVI is not an option for us, since the results we get from SVI are not good.
Parameters for the NUTS sampler:
MAX_DEPTH_TREE = 4
Parameters for MCMC:
NUM_SAMPLES = 1000, NUM_WARMUP = 500, NUM_CHAINS = 1
I suspect some numerical issues are happening. Could you try changing the random seed or using float64?
Thanks for the quick response! We’ve changed the input data to float64 for the MCMC fit, but it doesn’t improve the speed much. I used jrandom.key(123) as the random seed; in the previous version, we used jrandom.PRNGKey(123). Could this be the issue?
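On the key question, a small check that can be run directly. It assumes a JAX version recent enough to provide jax.random.key and jax.random.key_data, and the default PRNG implementation. Under those assumptions the two key styles should carry the same key data and produce the same draws, so the change of key type by itself seems an unlikely cause of the slowdown (that is an inference, not something measured in this thread).

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

legacy_key = jrandom.PRNGKey(123)  # legacy style: raw uint32 array
typed_key = jrandom.key(123)       # new style: typed key array

# Compare the raw key data behind both keys.
same_bits = bool(jnp.array_equal(jrandom.key_data(typed_key), legacy_key))

# Compare the random draws produced from each key.
draws_legacy = jrandom.normal(legacy_key, (5,))
draws_typed = jrandom.normal(typed_key, (5,))
same_draws = bool(jnp.array_equal(draws_legacy, draws_typed))

print("same key data:", same_bits)
print("same draws:", same_draws)
```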
Thanks! I’ve tried the configuration:
import os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.8")
# Compilation optimizations
xla_flags = [
    "--xla_cpu_multi_thread_eigen=false",
    "--xla_force_host_platform_device_count=1",
    "--xla_cpu_enable_fast_math=true",
    "--xla_cpu_use_eigen_gemm=false",
]
os.environ.setdefault("XLA_FLAGS", " ".join(xla_flags))
# Python 3.12 specific settings
os.environ.setdefault("JAX_ENABLE_X64", "true")
os.environ.setdefault("JAX_PLATFORMS", "cpu")
# Note: this plain assignment replaces the XLA_FLAGS value set above
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
With those settings, we still can’t bring the runtime back to the original speed. An update on the slowdown: 255 steps run at 2.17it/s in Python 3.10 and at 3.14s/it in Python 3.12. Do you have any insight into why this happens?
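One detail in the configuration above: the final os.environ['XLA_FLAGS'] assignment overwrites whatever the earlier setdefault call stored, so only the thunk-runtime flag ends up in the variable. A minimal sketch of a helper that merges flags instead; add_xla_flags is a hypothetical name, not part of JAX, and it is shown on a plain dict so the real environment is untouched. As far as I understand, XLA_FLAGS needs to be set before JAX initializes its backend, so in practice this would run before importing jax.

```python
import os


def add_xla_flags(new_flags, env=os.environ):
    """Merge flags into XLA_FLAGS; a later value for the same flag name wins."""
    merged = {}
    for flag in env.get("XLA_FLAGS", "").split() + list(new_flags):
        name = flag.split("=", 1)[0]
        merged[name] = flag
    env["XLA_FLAGS"] = " ".join(merged.values())
    return env["XLA_FLAGS"]


# Demonstration on a throwaway dict rather than os.environ.
demo_env = {}
add_xla_flags(
    [
        "--xla_cpu_multi_thread_eigen=false",
        "--xla_force_host_platform_device_count=1",
    ],
    demo_env,
)
combined = add_xla_flags(["--xla_cpu_use_thunk_runtime=false"], demo_env)
print(combined)
```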
I don’t have further ideas. If you have some reproducible code (try to simplify the model), we can try to identify which operator causes the regression.
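To narrow down an operator, a rough timing harness like the one below can be run unchanged in both environments. It is only a sketch: time_jitted is a hypothetical helper, and the three candidate operations are placeholders that should be replaced with whatever the simplified model actually uses. It separates the first call (which includes compilation) from the steady-state time per call.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def time_jitted(fn, *args, repeats=5):
    """Return (first-call seconds incl. compilation, mean steady-state seconds)."""
    jitted = jax.jit(fn)
    start = time.perf_counter()
    jitted(*args).block_until_ready()
    first_call = time.perf_counter() - start
    start = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()
    steady = (time.perf_counter() - start) / repeats
    return first_call, steady


# Placeholder inputs (assumption: 300x300 matrices are large enough to time).
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (300, 300))
spd = a @ a.T + 300.0 * jnp.eye(300)  # symmetric positive definite

candidates = {
    "matmul": (lambda m: m @ m, a),
    "cholesky": (jnp.linalg.cholesky, spd),
    "logsumexp": (lambda m: logsumexp(m, axis=-1), a),
}
timings = {name: time_jitted(fn, arg) for name, (fn, arg) in candidates.items()}

for name, (first_call, steady) in timings.items():
    print(f"{name}: first call {first_call:.4f}s, steady {steady:.6f}s")
```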
Our model structure is similar to this example, but with higher dimension. I also ran this example in Python 3.10 with the older version of NumPyro and in Python 3.12 with the newer version. Sampling took 1 minute 15 s in the Python 3.10 setup and 2 minutes 11 s in the Python 3.12 setup. I wonder if these versions are not yet fully compatible with 3.12.
Could you isolate the issue by checking whether it is due to the newer version of NumPyro, the newer version of JAX, or the newer version of Python?
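When changing one component at a time, it helps to record exactly what each environment contains next to each timing. A small sketch using only the standard library; report_versions is a hypothetical helper name, and the default package list is an assumption about what matters here.

```python
import platform
from importlib import metadata


def report_versions(packages=("jax", "jaxlib", "numpyro")):
    """Return the Python version plus installed versions (None if missing)."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in report_versions().items():
    print(f"{name}: {version}")
```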
Thanks!! The older versions of NumPyro and JAX are not compatible with 3.12. I tried JAX 0.5.0 and NumPyro 0.18.0 with Python 3.10. The warmup takes a long time: it’s now 3.68s/it, where previously we saw 1.44it/s. Do you think it’s because the package versions we used are suboptimal?