Pickling a guide function is failing

I’m trying to pickle the guide from a NumPyro model fit with SVI, and I’m running into a ConcretizationTypeError. Any ideas why this is happening? If guides can’t be pickled, is there a recommended way to save them?

I’m using AutoLowRankMultivariateNormal as the guide.

with open('guide.pickle', 'wb') as handle:
    pickle.dump(model.guide, handle)

Here’s the traceback:

---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
Cell In[27], line 2
      1 with open('guide.pickle', 'wb') as handle:
----> 2     pickle.dump(model.guide, handle)

File ~/.pyenv/versions/3.8.13/envs/demand_dev38/lib/python3.8/site-packages/jax/_src/core.py:676, in Tracer.__reduce__(self)
    675 def __reduce__(self):
--> 676   raise ConcretizationTypeError(
    677     self, ("The error occurred in the __reduce__ method, which may "
    678            "indicate an attempt to serialize/pickle a traced value."))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape uint32[2].
The error occurred in the __reduce__ method, which may indicate an attempt to serialize/pickle a traced value.
This DynamicJaxprTracer was created on line /Users/kcaron/.pyenv/versions/3.8.13/envs/demand_dev38/lib/python3.8/site-packages/numpyro/handlers.py:729 (process_message)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
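The last frame shows JAX’s Tracer.__reduce__ refusing to serialize a traced value, and uint32[2] is the shape of a raw JAX PRNG key. Below is a minimal JAX-only sketch of that failure mode. It assumes, without confirming it for this model, that the guide ended up holding a reference to a value created inside a jit-compiled call. The dict cache and the function step are hypothetical stand-ins for the guide and for whatever traced code touched it.

```python
import pickle

import jax

# Hypothetical stand-in for an object (such as a guide) that caches arrays.
cache = {}


@jax.jit
def step(key):
    # Inside jit, `key` and everything derived from it are abstract tracers.
    new_key = jax.random.split(key)[0]
    cache["key"] = new_key  # side effect: stores (leaks) a uint32[2] tracer
    return new_key


def can_pickle(obj):
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:  # ConcretizationTypeError in the JAX version from the traceback
        return False


returned_key = step(jax.random.PRNGKey(0))
leaked_ok = can_pickle(cache)  # False: the cached value is a leaked tracer

# The value *returned* from the jitted call is concrete; copying it to the
# host as a NumPy array makes it safe to pickle.
cache["key"] = jax.device_get(returned_key)
concrete_ok = can_pickle(cache)  # True
print(leaked_ok, concrete_ok)
```

If this is what happened, pickling anything that still references the cached tracer will fail the same way. Values that come out of the jitted function as return values are fine.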

I’m not sure whether this is something obvious; if more details about the model are needed, let me know!

NumPyro version: 0.12.1
JAX version: 0.4.13
Python version: 3.9.7