astenuz
September 11, 2020, 10:01pm
Hi! I was wondering whether there is a short reference on how to save and load a model. I found two threads, but it was not clear to me how to actually save and load the model. The model in question is a VAE adapted from here to work with PyTorch Lightning.
The Colab notebook for it is here.
I’m only interested in restoring the model so it can be evaluated on new data. Later I may figure out how to override Lightning’s automatic checkpointing so it works with Pyro; however, there seem to be many issues with saving the optimizer state.
Thank you very much for your time!
fritzo
September 18, 2020, 12:52am
I believe that since your entire model is an nn.Module, you can simply use torch.save() and torch.load().
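For a plain nn.Module, the usual PyTorch pattern is to save the module's state_dict and load it into a freshly constructed instance of the same class. A minimal sketch of that round trip, where TinyNet is a hypothetical stand-in for the VAE and an in-memory buffer stands in for a file on disk:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the VAE in the question."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)


torch.manual_seed(0)
model = TinyNet()

# Save only the tensors (the state_dict) rather than pickling the whole object.
buffer = io.BytesIO()  # stands in for a file path such as "model.pt"
torch.save(model.state_dict(), buffer)

# Restore: rebuild the module from its class, then copy the saved tensors in.
buffer.seek(0)
restored = TinyNet()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # evaluation mode (matters for dropout / batch-norm layers)

x = torch.randn(4, 3)
with torch.no_grad():
    outputs_match = torch.allclose(model(x), restored(x))
print(outputs_match)
```

Saving the state_dict rather than the whole object avoids pickling the class itself; on recent PyTorch releases torch.load defaults to weights_only=True, which accepts a plain state_dict but may refuse an arbitrary pickled module.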
everli
September 18, 2020, 2:25am
I used torch.save() and torch.load(), but loading failed. My code is below; can you see why?
class BayesianRegression(PyroModule):
    def __init__(self, input_size=1, output_size=1):
        super().__init__()
        self.linear = PyroModule[nn.Linear](input_size, output_size)
        ww=dist.Normal(0., 1.).expand([output_size, input_size]).to_event(2)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([output_size, input_size]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        mean=mean.view(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

def save_checkpoint():
    print('start-save')
    torch.save(model.state_dict(),path+'/result_para/saved_params.save')

def load_checkpoint():
    print("start-load")
    state_dict=torch.load(path+'/result_para/saved_params.save')
    model.load_state_dict(state_dict)

for iteration in range(n_iterations):
    loss=0
    for step, (batch_x, batch_y) in enumerate(loader_training): # for each training step
        loss+=svi.step(batch_x,batch_y)
    elbo.append( -loss)
    ##########################################################3
    #print("Epoch ", epoch, " Loss ", total_epoch_loss_train)
    if iteration % 100 == 0:
        if iteration==1000:
            # save_predictive(predictive)
            save_checkpoint()
            print('******************')

#analysis
guide.requires_grad_(False)
pyro.clear_param_store()
load_checkpoint()
predictive = Predictive(model, guide=guide,
                        num_samples=1000)
# predictive=load_predictive()
samples = predictive(trainX)
Errors:
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
During handling of the above exception, another exception occurred:
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
     Trace Shapes:
      Param Sites:
     Sample Sites:
        sigma dist   |
             value 1 |
linear.weight dist   | 1 1
             value 1 | 1 1
  linear.bias dist   | 1
             value 1 | 1
fritzo
September 18, 2020, 2:03pm
Could you try calling pyro.clear_param_store() just before loading?
everli
September 19, 2020, 12:12am
Thanks for your reply! Actually, I already call pyro.clear_param_store() before loading. Is there another way to deal with this? Thank you!
fritzo
September 25, 2020, 2:11am
Maybe simplify to
torch.save(model, path)
pyro.clear_param_store()
model = torch.load(path)
I suspect PyroModule isn’t playing well with the load_state_dict() method.
everli
September 26, 2020, 4:32am
Thanks, but the problem persists with your suggestion. Should we be saving the guide, or something else?
fritzo
September 28, 2020, 7:28pm
Hmm, I believe the predictive object should hold all the needed state. What if you save that instead of the model and guide?
torch.save(predictive, path)
pyro.clear_param_store()
predictive = torch.load(path)
...evaluate on new data using predictive...