Hello,
I am getting the following error after my chain finishes sampling, but before MCMC.run() has finished collecting the results:
mcmc.run(rng_key_, pedict, injdict, constants['nObs'], constants['obs_time'], constants['total_inj'], mass_models, mag_model, tilt_model, z_model, parsargs.mmin, parsargs.mmax, nspline_dict, param_names, rngkey=catkey)
File "/home/jaxeng/.conda/envs/gwinferno_jaxns/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 682, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jaxeng/.conda/envs/gwinferno_jaxns/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 500, in _single_chain_mcmc
states[self._sample_field] = postprocess_fn(states[self._sample_field])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Can't pack device memory arguments array of size 1255 which is larger than the maximum supported size of 1024
I only get this error with large warmup and sample sizes (~100k). I think it started when I began saving a large number of deterministic variables during sampling. I have done this in the past with older versions of numpyro/jax without issue, using the same sample sizes, deterministic variables, GPU, etc.
I’m using the NUTS kernel and an 80GB Nvidia A100 GPU.
Any ideas of what is going on would be appreciated!