JAX ConcretizationTypeError when saving an MCMC object with dill after running with chain_method='parallel'

Hi Friends!

I was wondering whether anyone has a fix for this. I'm running MCMC with more than one chain:

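```python
from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

# bymodeln and the ref_*/fit_*/n_obs_days inputs are defined elsewhere in my code.
# NUTS handles the continuous latent sites, DiscreteHMCGibbs the discrete ones.
kernel = DiscreteHMCGibbs(
    NUTS(bymodeln.both_hems_multi_fit_model, dense_mass=True), modified=True
)
mcmc = MCMC(kernel, num_warmup=2, num_samples=1, progress_bar=True,
            num_chains=2, chain_method='parallel')

mcmc.run(
    random.PRNGKey(0),
    ref_lat=ref_lat,
    ref_time=ref_time,
    ref_area=ref_area,
    fit_lat=fit_lat,
    fit_time=fit_time,
    n_obs_days=n_obs_days,
)
```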

and then try to save it to a file using dill:

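```python
import os

import dill

MCMC_OUT_FOLDER = '/d0/amunozj/bfly400/output/mcmc-nuts/'

# os.path.join avoids the double slash from appending '/' to a folder that already ends in one
output_file = os.path.join(
    MCMC_OUT_FOLDER, f'mcmc_fit_c{fit_cycles[0]}_unraveled_selection_chains.pkl'
)
with open(output_file, 'wb') as f:  # open in binary write mode
    dill.dump(mcmc, f)  # serialize the MCMC object; the with-block closes the file
```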

And I get the error:

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int64[32]
```

The error goes away if I use `chain_method='sequential'`.
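To narrow down which part of the object triggers the error, one option is to try serializing each attribute on its own. Below is a minimal sketch; the helper name `find_unpicklable_attributes` is hypothetical. It inspects `vars(obj)` directly, so it ignores any custom `__getstate__` the class may define and can flag attributes that a normal pickle would exclude.

```python
import pickle
from typing import Any, Callable, Dict


def find_unpicklable_attributes(
    obj: Any, dumps: Callable[[Any], bytes] = pickle.dumps
) -> Dict[str, str]:
    """Try to serialize each instance attribute of `obj` separately.

    Returns a mapping from attribute name to the exception raised while
    serializing it; an empty dict means every attribute serialized cleanly.
    """
    failures: Dict[str, str] = {}
    for name, value in vars(obj).items():
        try:
            dumps(value)
        except Exception as exc:  # record any failure, whatever its type
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures
```

With the object from the run above this would be called as `find_unpicklable_attributes(mcmc, dumps=dill.dumps)`, which reports the attribute names whose serialization raises.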

Has anybody run into an issue like this? Any idea how to fix it?
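If only the posterior draws need to be kept, a possible workaround (not a confirmed fix for the `parallel` case) is to save the samples dict instead of the whole `MCMC` object. The sketch below assumes the input is what `mcmc.get_samples(group_by_chain=True)` returns; the helper names `save_samples` and `load_samples` are hypothetical, and anything else stored on the `MCMC` object is not saved.

```python
import pickle
from typing import Any, Dict, Mapping

import numpy as np


def save_samples(samples: Mapping[str, Any], path: str) -> None:
    """Pickle a dict of posterior draws after converting each array to NumPy.

    Converting with np.asarray means the file holds plain NumPy arrays
    rather than JAX array objects.
    """
    host_samples = {name: np.asarray(value) for name, value in samples.items()}
    with open(path, "wb") as f:  # binary mode, as pickle requires
        pickle.dump(host_samples, f)


def load_samples(path: str) -> Dict[str, np.ndarray]:
    """Read back a dict written by save_samples."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Usage would be `save_samples(mcmc.get_samples(group_by_chain=True), output_file)`; with `group_by_chain=True` each array keeps a leading chain dimension, so the two chains stay separate in the saved file.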