Save model and load for evaluation

Hi! I was wondering whether there is a short reference on how to save and load a model. I saw two threads, but neither made clear how to actually save and load the model. The model in question is a VAE adapted from here to work with PyTorch Lightning.

The Colab notebook is here.

I’m interested in just restoring the model so it can be evaluated on new data. Maybe later I can figure out how to override lightning’s automatic checkpointing so it works with pyro. However, there seem to be a lot of issues in saving the state of the optimizers.

Thank you very much for your time!


I believe since your entire model is an nn.Module, you can simply use torch.save() and torch.load().
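For a plain nn.Module, a minimal sketch of that round trip could look like the following. It saves the state_dict rather than the whole pickled module, which is a common variation and not necessarily what the reply above had in mind. The nn.Linear layer and the in-memory buffer are stand-ins chosen for illustration; a real script would pass a file path instead.

```python
import io

import torch
from torch import nn

# Stand-in model (hypothetical); any nn.Module works the same way.
model = nn.Linear(3, 1)

# Save only the parameters. The BytesIO buffer stands in for a file path.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Restore: build a module with the same architecture, then load the weights.
buffer.seek(0)
restored = nn.Linear(3, 1)
restored.load_state_dict(torch.load(buffer))
restored.eval()  # switch to evaluation mode before scoring new data
```

Note that this only covers tensors registered on the module as parameters or buffers. Whether that is enough for a Pyro model depends on where its learned values live.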

I used torch.save() and torch.load(), but loading fails. My code is below; can you see why?

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive
from pyro.nn import PyroModule, PyroSample

# model, guide, svi, path, loader_training, n_iterations, elbo and trainX
# are defined elsewhere in the script.

class BayesianRegression(PyroModule):
    def __init__(self, input_size=1, output_size=1):
        super().__init__()
        self.linear = PyroModule[nn.Linear](input_size, output_size)
        # Priors over the weight and bias of the linear layer
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([output_size, input_size]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        mean = mean.view(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

def save_checkpoint():
    print('start-save')
    torch.save(model.state_dict(), path + '/result_para/saved_params.save')

def load_checkpoint():
    print("start-load")
    state_dict = torch.load(path + '/result_para/saved_params.save')
    model.load_state_dict(state_dict)

# training
for iteration in range(n_iterations):
    loss = 0
    for step, (batch_x, batch_y) in enumerate(loader_training):  # for each training step
        loss += svi.step(batch_x, batch_y)
    elbo.append(-loss)

    # print("Epoch ", epoch, " Loss ", total_epoch_loss_train)
    if iteration % 100 == 0:
        if iteration == 1000:
            # save_predictive(predictive)
            save_checkpoint()
            print('******************')

# analysis
guide.requires_grad_(False)
pyro.clear_param_store()
load_checkpoint()

predictive = Predictive(model, guide=guide,
                        num_samples=1000)

# predictive = load_predictive()
samples = predictive(trainX)

Errors:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D

During handling of the above exception, another exception occurred:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
     Trace Shapes:
      Param Sites:
     Sample Sites:
        sigma dist   |
             value 1 |
linear.weight dist   | 1 1
             value 1 | 1 1
  linear.bias dist   | 1
             value 1 | 1

Could you try calling pyro.clear_param_store() just before loading?

Thanks for your reply! Actually, I already call pyro.clear_param_store() before loading. Do you have any other suggestions? Thank you!

Maybe simplify to

torch.save(model, path)

pyro.clear_param_store()
model = torch.load(path)

I suspect PyroModule isn’t playing well with the load_state_dict() method.

Thanks, but the problem persists with your suggestion. I am not sure: should we save the guide, or something else?