I was having an issue with `pyro.get_param_store().load()`: after loading the param store, I verified (via `print(pyro.param("decoder$$$fc21.bias").data.numpy())`) that the params matched what they were before saving, but the model wasn't using them. Specifically, my VAE, modified from [the tutorial](http://pyro.ai/examples/vae.html), was reconstructing noise.
Pull request #47 ("add save/load to param store (with some accompanying changes)" by martinjankowiak, pyro-ppl/pyro on GitHub) and issue #22 ("Add ability to save and load param store", pyro-ppl/pyro on GitHub) discuss this issue, but they're outdated: the `pyro.sync_module(...)` mentioned in the pull request isn't implemented in the current 0.2.1 version or in dev. `.load()` was also brought up 27 days ago here, where the answer is a link to the Pyro documentation for `pyro.get_param_store().load()`, which makes it seem like that's all that needs to be done.
I eventually found the solution by searching the rest of the documentation, under `pyro.module()`. One has to re-register the modules with the argument `update_module_params=True` (it defaults to `False`).
Specifically,

    import pyro

    vae = VAE(...)
    # vae has encoder and decoder modules defined in the VAE class
    pyro.get_param_store().load('trained_models/vae_pretrained.save')
    # Right now, if you use vae, it'll reconstruct noise: the modules' params aren't updated.
    pyro.module("decoder", vae.decoder, update_module_params=True)
    pyro.module("encoder", vae.encoder, update_module_params=True)
    # It should work now.

in order for `vae` to properly use the loaded param store values.
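
For anyone wondering why loading alone isn't enough, here is a toy sketch in plain Python of the behaviour described above. It is not Pyro's implementation, just my mental model of it: `ToyModule`, `ToyParamStore`, and `register_module` are made-up names, and the assumption is that registering a module without the flag leaves the module's own freshly initialised weights untouched even when the store already holds saved values.

```python
# Toy model of the behaviour described above; NOT Pyro's real implementation.
# ToyModule, ToyParamStore and register_module are hypothetical names.


class ToyModule:
    """Stands in for an nn.Module: it owns its own copy of the weights."""

    def __init__(self, weights):
        self.weights = dict(weights)


class ToyParamStore:
    """Stands in for the global param store: a name -> value mapping."""

    def __init__(self):
        self.params = {}

    def load(self, saved):
        # Loading only fills the store; no module is touched here.
        self.params = dict(saved)


def register_module(store, name, module, update_module_params=False):
    """Rough analogue of pyro.module(name, module, update_module_params=...)."""
    for key, value in list(module.weights.items()):
        full_name = f"{name}$$${key}"
        # A value already in the store wins over the module's own value.
        stored = store.params.setdefault(full_name, value)
        if update_module_params:
            # Only with this flag is the stored value copied back into the module.
            module.weights[key] = stored
    return module


saved = {"decoder$$$fc21.bias": 0.5}        # pretend this came from training
store = ToyParamStore()
store.load(saved)

decoder = ToyModule({"fc21.bias": -9.0})    # freshly initialised weights
register_module(store, "decoder", decoder)  # default: update_module_params=False
stale_bias = decoder.weights["fc21.bias"]   # still the fresh value

register_module(store, "decoder", decoder, update_module_params=True)
loaded_bias = decoder.weights["fc21.bias"]  # now the saved value

print(stale_bias, loaded_bias)
```

In this toy version the store holds the right value the whole time, which matches what I saw when printing the param, while the module only picks it up after the second registration.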
Could the documentation for `pyro.get_param_store().load()` be updated to reflect this additional re-registration step?
Thanks