Saving Gaussian Process parameters from previous session

I am implementing the deep kernel learning example found here: https://pyro.ai/examples/dkl.html, but I am having trouble saving and loading the trained model.

After training the gpmodule, I used torch.save(gpmodule.state_dict(), "model.p") to save its parameters for use in another session. However, when I start a new session, initialize a gpmodule object, and call gpmodule.load_state_dict() to load them, the model performs as if it were randomly initialized.

Am I saving the model parameters incorrectly?

Hi @gtorres, could you try the save and load methods of ParamStoreDict (you can get the param store with pyro.get_param_store())? I think this is the approach used by pyro.nn.Module, which GP depends on.

Thanks for the tip, @fehiepsi! I tried saving the parameters with pyro.get_param_store().save('test.pl'), then opened another session and tried:

pyro.get_param_store().load('test.pl')
pyro.module('module', nn, update_module_params=True)

But this returned an error:

File "/home/gtorres/anaconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py", line 340, in module
    assert hasattr(nn_module, "parameters"), "module has no parameters"
AssertionError: module has no parameters

Sorry, I can confirm that saving and loading through the param store does not update torch.nn.Parameter attributes (only PyroParam attributes). However, your original approach

torch.save(gpmodule.state_dict(), 'abc/sgp.pt')
...
gpmodule.load_state_dict(torch.load('abc/sgp.pt'))

works as expected. Could you check it again?

Another way is to use

torch.save(gpmodule, 'abc/gpm.pt')
...
gpmodule = torch.load('abc/gpm.pt')

I got it working for the MNIST data in the example (using torch.save and load_state_dict as you suggested). However, when I try the same thing with a different dataset (larger, with more classes, but otherwise an identical architecture), it doesn't load correctly for some reason. It's very strange, since the scripts are identical apart from the input data. I think this is an issue on my end that I need to investigate. Thank you!
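One way to narrow down a load that appears to succeed but behaves like a fresh model is to verify the round trip explicitly. This is a debugging sketch, not something from the thread: load_and_verify is a hypothetical helper, and a small torch.nn.Sequential stands in for the DKL network. If a script builds more than one GP model in the same process, it may also be worth calling pyro.clear_param_store() before constructing and loading the model, since Pyro modules can pick up same-named values already held in the global param store.

```python
import io

import torch


def load_and_verify(model: torch.nn.Module, state: dict) -> None:
    """Load `state` into `model` and raise if anything did not transfer.

    strict=True makes load_state_dict raise on missing or unexpected keys
    (shape mismatches always raise); the loop then confirms that every saved
    tensor now matches the model's copy.
    """
    model.load_state_dict(state, strict=True)
    current = model.state_dict()
    for name, saved in state.items():
        if not torch.equal(current[name], saved):
            raise RuntimeError(f"parameter {name!r} was not restored")


def make_net(num_classes: int) -> torch.nn.Module:
    # Hypothetical stand-in for the real feature extractor + GP module.
    return torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, num_classes)
    )


torch.manual_seed(0)
source = make_net(num_classes=3)
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)

# Same architecture: loads and verifies cleanly.
buffer.seek(0)
target = make_net(num_classes=3)
load_and_verify(target, torch.load(buffer))

# Different output size: load_state_dict raises instead of loading silently.
buffer.seek(0)
wrong = make_net(num_classes=5)
try:
    load_and_verify(wrong, torch.load(buffer))
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```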