PyTorch | Saving and Loading Models Tutorial

lc013 2021-09-15 08:32:23

 

PyTorch | Saving and Loading Models Tutorial: Sample Code

Author: Jenny Caywood

Original title: SAVING AND LOADING MODELS

Original author: Matthew Inkawhich

Introduction

This article explains how to save and load PyTorch models. Three core functions are involved:

  1. torch.save: saves a serialized object to disk. Serialization is done with Python's pickle module. Models, tensors, and dictionaries of all kinds of objects can be saved with this function.

  2. torch.load: uses pickle's unpickling facilities to deserialize saved object files into memory.

  3. torch.nn.Module.load_state_dict: loads a model's parameter dictionary from a deserialized state_dict.

This article covers the following topics:

  • What is a state dictionary (state_dict)?

  • Saving and loading models for inference

  • Saving and loading a general checkpoint

  • Saving multiple models in one file

  • Warmstarting a model using parameters from a different model

  • Saving and loading models across devices

1. What is a state dictionary (state_dict)?

In PyTorch, the learnable parameters of a model (torch.nn.Module), that is, its weights and biases, are contained in the model's parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary that maps each layer to its parameter tensors. Only layers with learnable parameters (convolutional layers, fully connected layers, etc.) and registered buffers (such as batchnorm's running_mean) have entries in a model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer's state as well as the hyperparameters used.

Because state dictionaries are plain Python dictionaries, PyTorch models and optimizers are easy to save, update, alter, and restore.

Here is a simple example, adapted from the CIFAR-10 classifier tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```


The code above defines a simple five-layer CNN (two convolutional layers and three fully connected layers, plus max pooling), then prints the model's state_dict and the optimizer's state_dict.

Output (the exact contents of the optimizer's param_groups depend on the PyTorch version):

```
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
```
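To make the point about which layers appear in a state_dict more concrete, here is a small sketch. The toy nn.Sequential model is a hypothetical example, not part of the original tutorial: it mixes a convolution (parameters), a BatchNorm layer (parameters plus registered buffers), and parameter-free ReLU and MaxPool layers, then edits one entry of the dictionary and loads it back.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: conv (params), batchnorm (params + buffers),
# ReLU and MaxPool (neither params nor buffers)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

sd = net.state_dict()
for key, tensor in sd.items():
    print(key, tuple(tensor.shape))

# Buffers such as running_mean appear in the state_dict but not in parameters()
param_names = {name for name, _ in net.named_parameters()}
buffer_names = set(sd) - param_names
print("buffers:", sorted(buffer_names))

# A state_dict is an ordinary (ordered) Python dict: edit an entry, then load it back
sd["0.bias"] = torch.zeros(4)
net.load_state_dict(sd)
```

The ReLU and MaxPool layers (indices 2 and 3) contribute no keys at all.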


2. Saving and loading models for inference

Saving/loading the state_dict (recommended)

Saving:

```python
torch.save(model.state_dict(), PATH)
```


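Loading:

```python
model = TheModelClass(*args, **kwargs)    # rebuild the architecture first
model.load_state_dict(torch.load(PATH))   # then load the saved parameters
model.eval()                              # switch to evaluation mode for inference
```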


When saving a model for inference, it is only necessary to save the trained model's learned parameters. Saving the model's state_dict with torch.save() gives you the most flexibility for restoring the model later, which is why it is the recommended method.

A common PyTorch convention is to save models with a .pt or .pth file extension.

Remember:

  1. You must call model.eval() before running inference to set dropout and batch normalization layers to evaluation mode. Failing to do so will yield inconsistent inference results.

  2. load_state_dict() takes a dictionary object, not a path to a saved object. You must therefore deserialize the saved state_dict with torch.load() before passing it to load_state_dict(), as in the loading example above; you cannot call model.load_state_dict(PATH) directly.
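Both reminders can be checked with a short round trip. This is a sketch assuming a hypothetical SmallNet with a dropout layer, and it writes to a temporary directory instead of a fixed PATH. The restored model is built first, the file is deserialized with torch.load(), the resulting dict (not the path) is passed to load_state_dict(), and both models are put in eval mode before their outputs are compared.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical model with a dropout layer, used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.fc(x))


torch.manual_seed(0)
trained = SmallNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "small_net.pt")
    torch.save(trained.state_dict(), path)

    restored = SmallNet()                 # 1. rebuild the architecture
    state_dict = torch.load(path)         # 2. deserialize the file into a dict
    restored.load_state_dict(state_dict)  # 3. pass the dict, not the path
    restored.eval()                       # 4. dropout off for inference

trained.eval()
x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

If either model were left in training mode, dropout would zero random activations and the two outputs would generally differ.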

Saving/loading the entire model

Saving:

```python
torch.save(model, PATH)
```


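Loading:

```python
# Model class must be defined somewhere
# PyTorch 2.6+ defaults to weights_only=True, so a full pickled model needs
# weights_only=False; only do this for files you trust
model = torch.load(PATH, weights_only=False)
model.eval()
```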


This approach uses the most intuitive syntax and needs only a few lines of code. It saves the entire module with Python's pickle module. The drawback is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved. This is because pickle does not save the model class itself; instead, it saves a path to the file containing the class, which is used at load time. As a result, your code can break in various ways when the model is used in other projects or after refactors.
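A minimal sketch of the whole-model round trip, assuming a hypothetical TinyNet defined at the top level of the script. Because pickle records only where to find the class (its module and name), TinyNet must still be importable under that name when torch.load() runs; renaming or moving it makes loading fail. The weights_only=False argument (available in PyTorch 1.13 and later) is needed on recent releases, which default to weights_only=True and refuse to unpickle arbitrary objects.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; must be importable under the same name at load time."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "whole_model.pt")
    torch.save(model, path)  # pickles the module object, referencing TinyNet by name

    # Full module objects are not plain tensors, so weights_only must be False
    loaded = torch.load(path, weights_only=False)

loaded.eval()
# pickle stored where to find the class, not the class source code itself
print(type(loaded).__module__, type(loaded).__qualname__)
```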

3. Saving and loading a general checkpoint

Saving:

```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ...
}, PATH)
```


Loading:

```python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()    # for inference
# - or -
model.train()   # to resume training
```


When saving a general checkpoint, whether for inference or for resuming training, you must save more than just the model's state_dict. The optimizer's state_dict is also important, because it contains buffers and parameters that are updated as the model trains. Other items you may want to save include the epoch at which training stopped, the latest recorded training loss, external torch.nn.Embedding layers, and so on.

As the saving code above shows, these items are organized in a dictionary, which is then serialized with torch.save(). A common PyTorch convention is to save such checkpoints with the .tar file extension.

The loading code above shows the other direction: first initialize the model and optimizer, then load the dictionary with torch.load() and pass each saved state_dict to the matching load_state_dict() call. Any other saved value is retrieved simply by looking up its key.

After loading, call model.eval() to run inference, or model.train() to resume training.
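A runnable sketch of the checkpoint round trip, under a few assumptions that are not in the original: a hypothetical make_model_and_optimizer() factory, a toy nn.Linear model, one SGD step with momentum so that the optimizer actually has state (momentum buffers) worth saving, and the loss stored as a plain float via loss.item(). The restored optimizer's momentum buffers should match the originals.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)


def make_model_and_optimizer():
    # Hypothetical factory: model and optimizer must be rebuilt before loading
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


model, optimizer = make_model_and_optimizer()
x, y = torch.randn(16, 4), torch.randn(16, 1)

# One training step so the optimizer holds momentum buffers
loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch = 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.tar")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),  # plain float instead of a graph-attached tensor
    }, path)

    new_model, new_optimizer = make_model_and_optimizer()
    checkpoint = torch.load(path)
    new_model.load_state_dict(checkpoint["model_state_dict"])
    new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    last_loss = checkpoint["loss"]

new_model.train()  # resume training; use new_model.eval() for inference instead
```

Without the optimizer's state_dict, the fresh SGD optimizer would restart with empty momentum buffers, so resumed training would not continue exactly where it stopped.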

4. Saving multiple models in one file

Saving:

```python
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    # ...
}, PATH)
```


Loading:

```python
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```


When saving a model composed of multiple torch.nn.Modules, such as a GAN, a sequence-to-sequence model, or an ensemble of models, you follow the same approach as for a general checkpoint: save a dictionary containing each model's state_dict and the corresponding optimizer's state_dict. You can also save any other items that help you resume training.

The loading code above works just like loading a general checkpoint: first initialize the models and optimizers, then load the dictionary with torch.load() and query it by key. Again, such files are usually saved with the .tar extension.
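The same idea for two networks, sketched with a hypothetical generator/discriminator pair in which plain nn.Linear layers stand in for real networks. To illustrate another option, torch.save() and torch.load() are given an in-memory io.BytesIO buffer here instead of a file path; both functions accept file-like objects.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def build():
    # Hypothetical stand-ins for a generator (model A) and a discriminator (model B)
    modelA, modelB = nn.Linear(8, 16), nn.Linear(16, 1)
    optimizerA = optim.Adam(modelA.parameters(), lr=2e-4)
    optimizerB = optim.Adam(modelB.parameters(), lr=1e-4)
    return modelA, modelB, optimizerA, optimizerB


modelA, modelB, optimizerA, optimizerB = build()

buffer = io.BytesIO()  # a path such as "gan.tar" would work the same way
torch.save({
    "modelA_state_dict": modelA.state_dict(),
    "modelB_state_dict": modelB.state_dict(),
    "optimizerA_state_dict": optimizerA.state_dict(),
    "optimizerB_state_dict": optimizerB.state_dict(),
}, buffer)
buffer.seek(0)  # rewind before reading

# Initialize fresh models/optimizers first, then restore every entry from one file
newA, newB, newOptA, newOptB = build()
checkpoint = torch.load(buffer)
newA.load_state_dict(checkpoint["modelA_state_dict"])
newB.load_state_dict(checkpoint["modelB_state_dict"])
newOptA.load_state_dict(checkpoint["optimizerA_state_dict"])
newOptB.load_state_dict(checkpoint["optimizerB_state_dict"])
```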

5. Warmstarting a model using parameters from a different model

Saving:

```python
torch.save(modelA.state_dict(), PATH)
```


Loading:

```python
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)  # ignore non-matching keys
```


As covered in the earlier transfer learning tutorial, fine-tuning from a pretrained model can speed up training and improve accuracy.

Warmstarting usually means loading only part of a pretrained model's parameters as the initial parameters of a new model, which helps the new model converge faster.

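The loading code above sets strict=False in load_state_dict(), which ignores keys that do not match (keys missing from the loaded dictionary, or unexpected extra keys). This is needed because the new network usually differs from the pretrained one, most often in the output layer. Note that strict=False does not tolerate tensors whose names match but whose shapes differ: those still raise an error, so such entries must be removed from the loaded state_dict first.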

If you want to load parameters into layers whose names differ, rename the keys of the loaded state_dict to match the names in the model you are loading into; once the names match, loading succeeds.
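The sketch below combines renaming keys and skipping mismatched layers, under assumptions of its own: hypothetical ModelA and ModelB classes where B names the shared layer backbone instead of features, has a 3-class head instead of a 10-class one, and adds an extra layer. A shape filter runs before loading because strict=False only tolerates missing and unexpected keys; the classifier keys exist in both models with different shapes and would otherwise raise a RuntimeError. load_state_dict() returns the missing and unexpected keys, which is a convenient way to check what was actually loaded.

```python
import torch
import torch.nn as nn


class ModelA(nn.Module):
    """Hypothetical pretrained model: feature layer + 10-class head (forward omitted)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(32, 16)
        self.classifier = nn.Linear(16, 10)


class ModelB(nn.Module):
    """Hypothetical new model: renamed feature layer, 3-class head, one extra layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)
        self.classifier = nn.Linear(16, 3)
        self.extra = nn.Linear(3, 3)


state_a = ModelA().state_dict()  # stands in for torch.load(PATH)
modelB = ModelB()
target = modelB.state_dict()

# 1) Rename keys to follow modelB's naming ("features.*" -> "backbone.*")
renamed = {k.replace("features.", "backbone.", 1): v for k, v in state_a.items()}

# 2) Keep only entries that exist in modelB with an identical shape
compatible = {
    k: v for k, v in renamed.items()
    if k in target and v.shape == target[k].shape
}

# 3) strict=False tolerates the keys that are still missing (new head, extra layer)
result = modelB.load_state_dict(compatible, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```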

6. Saving and loading models across devices

Save on GPU, load on CPU

Saving:

```python
torch.save(model.state_dict(), PATH)
```


Loading:

```python
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))  # remap GPU tensors to CPU
```


When loading on a CPU a model that was trained on a GPU, pass torch.device('cpu') as the map_location argument of torch.load(). This remaps all tensors to the CPU.
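A small sketch of the two common forms of map_location. One caveat: a real CUDA checkpoint cannot be written on a CPU-only machine, so here the "checkpoint" comes from a CPU model saved into an in-memory buffer; when the file was written on a GPU, the same arguments are what move its tensors to the CPU. The callable form (returning the storage unchanged) is the variant shown in the torch.load documentation.

```python
import io

import torch
import torch.nn as nn

# Sketch only: the checkpoint is saved from a CPU model into memory
model = nn.Linear(3, 1)
buf = io.BytesIO()
torch.save(model.state_dict(), buf)

# Option 1: pass a device; every tensor is remapped onto it while loading
buf.seek(0)
sd_device = torch.load(buf, map_location=torch.device("cpu"))

# Option 2: pass a function; returning the storage unchanged keeps it on the CPU
buf.seek(0)
sd_callable = torch.load(buf, map_location=lambda storage, loc: storage)

cpu_model = nn.Linear(3, 1)
cpu_model.load_state_dict(sd_device)
```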

Save on GPU, load on GPU

Saving:

```python
torch.save(model.state_dict(), PATH)
```


Loading:

```python
device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
```


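When loading on a GPU a model that was trained and saved on a GPU, call model.to(torch.device('cuda')) after loading the state_dict to convert the model's parameter tensors to CUDA tensors. Every tensor you feed to the model must also be on the GPU, which you get with my_tensor.to(device). Note that my_tensor.to(device) returns a new copy of my_tensor on the GPU and does not overwrite my_tensor, so reassign it: my_tensor = my_tensor.to(device).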
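A device-agnostic sketch of the pattern described above. It is an assumption of this example (not in the original) that the code should fall back to the CPU when CUDA is unavailable, and an in-memory buffer stands in for the checkpoint file. It shows the asymmetry between nn.Module.to(), which moves the module's parameters in place, and Tensor.to(), which returns a new tensor that must be captured.

```python
import io

import torch
import torch.nn as nn

# Fall back to the CPU so the sketch also runs on machines without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a checkpoint file written earlier
buf = io.BytesIO()
torch.save(nn.Linear(5, 2).state_dict(), buf)
buf.seek(0)

model = nn.Linear(5, 2)
model.load_state_dict(torch.load(buf, map_location=device))
model.to(device)  # Module.to() moves the parameters in place (and returns the module)

x = torch.randn(4, 5)
x_dev = x.to(device)  # Tensor.to() returns a new tensor; x itself stays on the CPU

with torch.no_grad():
    out = model(x_dev)
print(out.device)
```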

Save on CPU, load on GPU

Saving:

```python
torch.save(model.state_dict(), PATH)
```


Loading:

```python
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
```


Here the model was trained and saved on the CPU but is loaded on a GPU, so set the map_location argument of torch.load() to cuda:device_id to load it onto the chosen GPU. Afterwards, remember to call model.to(torch.device('cuda')) as well.

Saving torch.nn.DataParallel models

Saving:

```python
torch.save(model.module.state_dict(), PATH)
```


torch.nn.DataParallel is a model wrapper that enables parallel use of multiple GPUs. To save a DataParallel model generically, save model.module.state_dict().

Loading works as before with torch.load(). Because the saved keys match the unwrapped model, you can load the parameters into a plain model and place it on whichever GPU you want.
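A sketch of why model.module.state_dict() is the generic choice: the wrapper's own state_dict prefixes every key with "module.", which a plain model does not expect. The strip_prefix() helper is hypothetical, shown for the case where a checkpoint was already saved from the wrapper. Without visible GPUs, DataParallel simply wraps the module, so this runs on a CPU-only machine.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)  # without GPUs it just wraps and calls the module

wrapped_keys = list(wrapped.state_dict().keys())       # prefixed with "module."
inner_keys = list(wrapped.module.state_dict().keys())  # same names as a plain model
print(wrapped_keys, inner_keys)


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop a leading prefix from every key that has it."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# A checkpoint saved from the wrapper can still be loaded into a plain model
plain = nn.Linear(4, 2)
plain.load_state_dict(strip_prefix(wrapped.state_dict()))
```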
