How to save the best model and guide during validation?

I created a deterministic convolutional neural network for classification and then lifted it to a probabilistic network using pyro.random_module(). I also tuned the learning rate as a hyperparameter during SVI optimization. While looping over SVI steps, I sampled the random network many times, e.g., sampled_models = [guide(None, None) for _ in range(num_model_samples)], and evaluated the sampled networks on a validation data set, because I want to keep the best probabilistic network. How can I save the best model and guide?

In PyTorch, we can save the best network (call it myNet) by

best_myNet_wts = copy.deepcopy(myNet.state_dict())
torch.save(best_myNet_wts, savePath)
myNet.load_state_dict(best_myNet_wts)

What are the counterparts of deepcopy, save, and load_state_dict in Pyro? Thanks.


You can save and load the parameters of the ParamStore. You can also use torch.save as you would in PyTorch:

output = {
'guide': guide,
'state_dict': my_nn.state_dict(),
'params': params
...
}
torch.save(output, outfile)

@jpchen, that works!

I saved the guide corresponding to the best prediction accuracy (about 98%) during validation and called it guide_best. Then I generated sample neural networks from guide_best, but these networks produced very low accuracy (40%) on test data. Did I miss anything? My code is:

num_model_samples = 50
sampled_models = [guide_best(None, None) for _ in range(num_model_samples)]
yhats = [model(x).data for model in sampled_models]

Did you load the params? You need to load the saved params into the param store, either by calling pyro.get_param_store().load() or by calling pyro.param on each param.

I am also having problems saving and loading my SVI results. Can someone help me solve this? Thank you!
My code is:

def save_checkpoint():
    print('start-save')
    output={'state_dict':model.state_dict(),
            'guide':guide,
            'params':pyro.get_param_store()
            }
    torch.save(output,path+'/result/saved_params.save')
    # torch.save(model.state_dict(),path+'/result/saved_params.save')

def load_checkpoint():
    print("start-load")
    output=torch.load(path+'/result/saved_params.save')
    model.load_state_dict(output['state_dict'])
    guide=output['guide']
    pyro.get_param_store().load(output['params'])

load_checkpoint()
predictive = Predictive(model, guide=guide, num_samples=1000,
                        return_sites=("obs",))
cali_samples = predictive(trainX)

Warning:
UserWarning: Couldn’t retrieve source code for container of type AutoDiagonalNormal. It won’t be checked for correctness upon loading.
Error:
TypeError: expected str, bytes or os.PathLike object, not ParamStoreDict

Can you give an example?


I figured it out after months of searching. They should put this explicitly in their documentation!

To save:

torch.save({"model" : model.state_dict(), "guide" : guide}, "mymodel.pt")
pyro.get_param_store().save("mymodelparams.pt")

To load:

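# Note: on recent PyTorch versions torch.load defaults to weights_only=True; loading the pickled guide object may then need weights_only=False (trusted files only).
saved_model_dict = torch.load("mymodel.pt")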
model.load_state_dict(saved_model_dict['model'])
guide = saved_model_dict['guide']
pyro.get_param_store().load("mymodelparams.pt")