Compilation failing after sampling finishes

Hello,

I am getting the following error after my chain finishes sampling but before MCMC.run() has finished compiling the results:

    mcmc.run(rng_key_, pedict, injdict, constants['nObs'], constants['obs_time'], constants['total_inj'], mass_models, mag_model, tilt_model, z_model, parsargs.mmin, parsargs.mmax, nspline_dict, param_names, rngkey=catkey)
  File "/home/jaxeng/.conda/envs/gwinferno_jaxns/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 682, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jaxeng/.conda/envs/gwinferno_jaxns/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 500, in _single_chain_mcmc
    states[self._sample_field] = postprocess_fn(states[self._sample_field])
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Can't pack device memory arguments array of size 1255 which is larger than the maximum supported size of 1024

I only get this error with large numbers of warmup and sampling steps (~100k). I think it started occurring when I began saving a large number of deterministic variables during sampling. However, I have done this in the past with older versions of NumPyro/JAX without issue, using the same sample sizes, deterministic variables, GPU, etc.

I’m using the NUTS kernel and an 80GB Nvidia A100 GPU.

Any ideas of what is going on would be appreciated!

Sounds like a JAX issue. You can try to isolate and reproduce it so the JAX developers can diagnose and fix it, or roll back to an earlier combination of NumPyro and JAX versions.