I’m trying to pickle the guide from an SVI NumPyro model and I’m running into a ConcretizationTypeError. Any ideas why this is happening? If guides can’t be pickled, is there a right way to save them?
I’m using AutoLowRankMultivariateNormal as the guide and saving it like this:

```python
import pickle

with open('guide.pickle', 'wb') as handle:
    pickle.dump(model.guide, handle)
```
Here’s the traceback:
```
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
Cell In[27], line 2
      1 with open('guide.pickle', 'wb') as handle:
----> 2     pickle.dump(model.guide, handle)

File ~/.pyenv/versions/3.8.13/envs/demand_dev38/lib/python3.8/site-packages/jax/_src/core.py:676, in Tracer.__reduce__(self)
    675 def __reduce__(self):
--> 676   raise ConcretizationTypeError(
    677       self, ("The error occurred in the __reduce__ method, which may "
    678              "indicate an attempt to serialize/pickle a traced value."))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape uint32[2].
The error occurred in the __reduce__ method, which may indicate an attempt to serialize/pickle a traced value.
This DynamicJaxprTracer was created on line /Users/kcaron/.pyenv/versions/3.8.13/envs/demand_dev38/lib/python3.8/site-packages/numpyro/handlers.py:729 (process_message)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
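The error is raised by Tracer.__reduce__, which means something reachable from the guide object is a JAX tracer: a placeholder value created while a function was being traced (for example under jax.jit) that got stored on a Python object and outlived the trace. The uint32[2] shape and the origin in numpyro/handlers.py (process_message) suggest it is a traced PRNG key, presumably cached in the guide’s setup state when the guide was first run inside a traced context. The sketch below reproduces the same failure with plain JAX; the `Holder` class and `step` function are hypothetical stand-ins, not NumPyro code.

```python
import pickle

import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError


class Holder:
    """Hypothetical stand-in for an object (such as a guide) that caches state."""

    def __init__(self):
        self.cached = None

    def __call__(self, x):
        # Storing a value computed inside a traced function puts a tracer on
        # the Python object, and it stays there after the jitted call returns.
        self.cached = x * 2
        return self.cached


holder = Holder()


@jax.jit
def step(x):
    return holder(x)


step(jnp.ones(2))  # holder.cached now holds a leaked tracer, not an array

try:
    # Pickling walks holder.__dict__ and reaches Tracer.__reduce__.
    pickle.dumps(holder)
    leaked = False
except ConcretizationTypeError:
    leaked = True

print(leaked)
```

The object itself is fine to pickle; it is only the leaked tracer attribute that cannot be serialized.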
I’m not sure whether this is something obvious; if more details about the model are needed, let me know!
numpyro version: 0.12.1
jax version: 0.4.13
Python version: 3.9.7