I am implementing the example found here: Example: Deep Kernel Learning — Pyro Tutorials 1.8.4 documentation. I am having trouble saving and loading the model.

After I trained the `gpmodule`, I used `torch.save(gpmodule.state_dict(), "model.p")` to save the parameters for another session. However, when I start another session, initialize a `gpmodule` object, and use `gpmodule.load_state_dict()` to load them, the model performs as if it were randomly initialized.

Am I saving the model parameters incorrectly?
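For reference, the pattern I'm describing looks like this — a minimal sketch with a plain `torch.nn.Module` standing in for the GP module (the `SimpleNet` class and file path are just placeholders):

```python
import os
import tempfile

import torch
import torch.nn as nn


# Placeholder network standing in for the trained gpmodule.
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


path = os.path.join(tempfile.mkdtemp(), "model.p")

# "Session 1": train (omitted here) and save only the parameters.
model = SimpleNet()
torch.save(model.state_dict(), path)

# "Session 2": re-initialize the same architecture and load the parameters.
model2 = SimpleNet()
model2.load_state_dict(torch.load(path))

# The reloaded parameters should match the saved ones exactly.
assert torch.equal(model.fc.weight, model2.fc.weight)
```

In this toy case the round trip restores the weights bit-for-bit, which is what I expected from the GP module as well.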
Hi @gtorres, could you try the `save` and `load` methods of `ParamStoreDict`? (To get the param store, you can call `pyro.get_param_store()`.) I think this is the approach used by `pyro.nn.Module`, which the GP module depends on.
Thanks for the tip @fehiepsi! I tried saving the parameters using `pyro.get_param_store().save('test.pl')`, and then I opened another session and tried:

```python
pyro.get_param_store().load('test.pl')
pyro.module('module', nn, update_module_params=True)
```

But this returned an error:

```
File "/home/gtorres/anaconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py", line 340, in module
    assert hasattr(nn_module, "parameters"), "module has no parameters"
AssertionError: module has no parameters
```
Sorry, I can confirm that using `save`/`load` from the param store does not update `torch.nn.Parameter` attributes (only `PyroParam` attributes). But your approach

```python
torch.save(gpmodule.state_dict(), 'abc/sgp.pt')
...
gpmodule.load_state_dict(torch.load('abc/sgp.pt'))
```

works as expected. Could you check it again?
Another way is to save and load the whole module:

```python
torch.save(gpmodule, 'abc/gpm.pt')
...
gpmodule = torch.load('abc/gpm.pt')
```
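A sketch of that whole-module variant with plain PyTorch (a built-in layer stands in for the GP module here so the object pickles cleanly; note that this approach pickles the full object, so the class must be importable in the loading session, and recent PyTorch versions need `weights_only=False` for full-object loads):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Built-in layer standing in for the GP module.
gpmodule = nn.Linear(2, 1)

path = os.path.join(tempfile.mkdtemp(), "gpm.pt")
torch.save(gpmodule, path)  # pickles the whole module object, not just tensors

# weights_only=False is required for full-object loads on recent PyTorch.
loaded = torch.load(path, weights_only=False)

# The reloaded module carries the same weights.
assert torch.equal(gpmodule.weight, loaded.weight)
```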
I got it to work for the MNIST data in the example (using `torch.save` and `load_state_dict` as you suggested). However, when I try the same thing with a different dataset (larger, with more classes, but an otherwise identical architecture), for some reason it doesn't load correctly. It's very bizarre, since there is no difference between the scripts; just the input data. I think this is an issue on my end that I need to investigate. Thank you!