I was having an issue with `pyro.get_param_store().load()`. After loading the param store, I verified (via `print(pyro.param("decoder$$$fc21.bias").data.numpy())`) that the params matched their values from before saving, but the model wasn't using them. Specifically, my VAE, modified from [the tutorial](http://pyro.ai/examples/vae.html), was reconstructing noise.
[Pull request 47](https://github.com/uber/pyro/pull/47) and [issue 22](https://github.com/uber/pyro/issues/22) discuss this issue, but they're outdated: the first link uses `pyro.sync_module(...)`, which isn't implemented in the current 0.2.1 release or in dev.
`.load()` was also brought up 27 days ago here, where the answer is a link to the Pyro documentation for `pyro.get_param_store().load()`, which makes it seem like that's all that needs to be done.
I eventually found the solution by searching the rest of the documentation, under `pyro.module()`: you have to re-register the modules with `update_module_params=True` (it defaults to `False`):
```python
import pyro

# pyro.get_param_store().load(...) has already been called at this point
vae = VAE(...)
# vae has encoder and decoder modules defined in the class definition.
# Right now, using vae reconstructs noise: the loaded params aren't applied to the modules.
pyro.module("decoder", vae.decoder, update_module_params=True)
pyro.module("encoder", vae.encoder, update_module_params=True)
# vae now uses the loaded param store values.
```
This re-registration is what makes `vae` actually use the loaded param store values.
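To make the failure mode concrete, here is a toy model in plain Python (no Pyro or PyTorch needed). `ToyParamStore` and `ToyModule` are hypothetical stand-ins written only for illustration, not Pyro's real classes. They model just what this post describes: parameters are stored under `module$$$param` names, `load()` refreshes the store, and a module's own parameters change only when it is re-registered with `update_module_params=True`.

```python
import copy

# Separator between module and parameter names, as in "decoder$$$fc21.bias".
DIVIDER = "$$$"


class ToyModule:
    """Hypothetical stand-in for an nn.Module: just a dict of named parameters."""

    def __init__(self, **params):
        self.params = dict(params)


class ToyParamStore:
    """Hypothetical stand-in for Pyro's param store; NOT Pyro's actual implementation."""

    def __init__(self):
        self.store = {}

    def save(self):
        # A snapshot of the store, playing the role of the saved file.
        return copy.deepcopy(self.store)

    def load(self, snapshot):
        # Overwrites store entries only; no live module object is touched.
        self.store.update(copy.deepcopy(snapshot))

    def module(self, name, mod, update_module_params=False):
        for pname, value in mod.params.items():
            key = name + DIVIDER + pname
            # Register the module's value only if the key is unseen;
            # otherwise keep what the store already holds.
            stored = self.store.setdefault(key, value)
            if update_module_params:
                # Copy the stored value back into the module.
                mod.params[pname] = stored
        return mod


# Session 1: "train" a decoder, register it, and save the store.
store = ToyParamStore()
store.module("decoder", ToyModule(bias=0.5))
snapshot = store.save()

# Session 2: load the store and build a freshly initialised decoder, like VAE(...).
store2 = ToyParamStore()
store2.load(snapshot)
decoder = ToyModule(bias=0.0)

print(store2.store["decoder$$$bias"])  # 0.5: the store looks correct
store2.module("decoder", decoder)      # default update_module_params=False
print(decoder.params["bias"])          # 0.0: the module still uses its own value

store2.module("decoder", decoder, update_module_params=True)
print(decoder.params["bias"])          # 0.5: now it uses the loaded value
```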
Could the documentation for `pyro.get_param_store().load()` be updated to mention this additional re-registration step?