Hello,
I am trying to save the model artifacts of my SVI session so I can run an inference process later on with a different script. I’m struggling to save the guide (as this is the object used in the inference process).
My first approach was to try to use pickle to save the guide, but pickle does not support some of the JAX components inside it. Dill, JSON, and CloudPickle do not seem to work either:
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

optimizer = Adam(constants.LEARNING_RATE)
hbd_guide = AutoNormal(hbd_model)
svi = SVI(hbd_model, hbd_guide, optimizer, loss=Trace_ELBO())
hbd_results = svi.run(
    random.PRNGKey(constants.RNG_SEED),
    n_epochs,
    data,
    train=True,
    stable_update=True,
)
hbd_results_json = {}

# Losses
hbd_results_json["losses"] = hbd_results.losses.to_py().tolist()

# Params
hbd_results_json["params"] = {}
for key in hbd_results.params.keys():
    hbd_results_json["params"][key] = hbd_results.params[key].to_py().tolist()

# Model and guide
hbd_results_json["hbd_model"] = hbd_model
hbd_results_json["hbd_guide"] = hbd_guide

with open(hbd_results_json_path_test, 'wb') as handle:
    pickle.dump(hbd_results_json, handle, protocol=pickle.HIGHEST_PROTOCOL)
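As a side note, the params themselves pickle fine once they are plain NumPy arrays; it is really the guide object that breaks. The snippet below is roughly equivalent to the loops above, assuming jax.device_get is the right way to pull a whole pytree of JAX arrays back to host NumPy arrays (the file name is just a placeholder):

import pickle
import jax

# Pull every leaf of the params pytree back to a plain NumPy array so
# pickle never has to handle a JAX-specific object.
params_np = jax.device_get(hbd_results.params)

with open("hbd_params.pkl", "wb") as handle:  # placeholder path
    pickle.dump(params_np, handle, protocol=pickle.HIGHEST_PROTOCOL)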
I have also tried saving the model (there are no issues pickling the model itself) and then rebuilding the guide from the model, but that does not work either. It seems that during training the guide is modified to include a plates field, which I do not get just by running the snippet below (after the snippet I sketch what I think might be missing):
with open(hbd_results_json_path_test, 'rb') as f:
    hbd_results_load = pickle.load(f)

for key in hbd_results_load["params"].keys():
    hbd_results_load["params"][key] = jnp.array(hbd_results_load["params"][key])

## Get guide_test
guide_test = AutoNormal(hbd_results_load["hbd_model"])

## Get posterior samples of latent variables
hbd_posterior = Predictive(
    guide_test,  # hbd_guide
    params=hbd_results_load["params"],
    num_samples=n_samples,
)
hbd_samples = hbd_posterior(random.PRNGKey(constants.RNG_SEED))

## Generate model fit -- train sample & test sample
hbd_predict = Predictive(
    model=hbd_model,
    posterior_samples=hbd_samples,
    guide=hbd_guide,
    params=hbd_results_load["params"],
    num_samples=n_samples,
)
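For context on the plates comment: my working assumption is that the guide only builds its plate structure once it has traced the model, which happens during SVI initialisation, so a freshly constructed AutoNormal has not gone through that step yet. This is the kind of warm-up I was considering in the inference script, though I have not confirmed it is the intended approach:

from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

# Rebuild an untrained guide around the reloaded model.
guide_test = AutoNormal(hbd_results_load["hbd_model"])

# Run the SVI initialisation once with the same model arguments as in
# training; I assume this traces the model and sets up the guide's plates.
svi_test = SVI(
    hbd_results_load["hbd_model"],
    guide_test,
    Adam(constants.LEARNING_RATE),
    loss=Trace_ELBO(),
)
_ = svi_test.init(random.PRNGKey(constants.RNG_SEED), data, train=True)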
Running hbd_posterior in the snippet above outputs the following error:
TypeError: hbd_model() missing 1 required positional arguments: 'data'
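If I read the traceback correctly, the failure happens when the guide traces hbd_model without its data argument. My (possibly wrong) understanding is that the positional/keyword arguments the model expects have to be passed again when the Predictive object is called, roughly:

# Forward the same model arguments used in training through the
# Predictive call, since they are passed on to the guide/model trace.
hbd_samples = hbd_posterior(random.PRNGKey(constants.RNG_SEED), data, train=True)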
I have looked for information on how to separate training from inference with SVI, but there is not much documentation on how to do this in NumPyro. All the examples I've seen run training and inference in the same script, so there is no need to save the model. Saving the model after training for a later inference process should not be such a headache, so it may be that I'm just using the wrong tools/code.
Has anyone encountered this issue?
Thanks in advance!