Author: Jenny Caywood
Original title: SAVING AND LOADING MODELS
Original author: Matthew Inkawhich
Introduction

This article introduces how to save and load PyTorch models. There are three core functions:

- `torch.save`: saves a serialized object to disk. It uses Python's `pickle` module to implement the serialization; models, tensors, and dictionaries can all be saved with this function.
- `torch.load`: uses `pickle` to deserialize and load an object from storage.
- `torch.nn.Module.load_state_dict`: loads a model's parameter dictionary from a deserialized `state_dict`.
The main contents of this article are as follows:

1. What is a state dictionary (`state_dict`)?
2. Saving and loading models for inference
3. Saving and loading a general checkpoint
4. Saving multiple models in one file
5. Warmstarting a model using parameters from another model
6. Saving and loading models across devices
1. What is a state dictionary (state_dict)?

In PyTorch, the learnable parameters of a model (`torch.nn.Module`), i.e. its weights and biases, are contained in the model's parameters (accessed via `model.parameters()`). A state dictionary is simply a Python dictionary whose key-value pairs map each layer to its parameter tensors. Only layers with learnable parameters (convolutional layers, fully connected layers, etc.) and registered buffers (such as `batchnorm`'s `running_mean`) have entries in the model's state dictionary. Optimizer objects (`torch.optim`) also have a state dictionary, which contains information about the optimizer's state as well as the hyperparameters it uses.

Because a state dictionary is just a Python dictionary, saving, updating, altering, and restoring PyTorch models and optimizers is easy to implement.
Here is a simple example, taken from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
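Below is a minimal sketch in the spirit of the original tutorial; the class name `TheModelClass` and the exact architecture are for illustration. It defines a small CNN with an SGD optimizer, then prints both state dictionaries:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# A small CNN with five learnable layers: two convolutional and three fully connected
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict: one entry per learnable parameter tensor
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print the optimizer's state_dict: optimizer state plus hyperparameters
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```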
The code above simply defines a five-layer CNN (two convolutional and three fully connected layers), then prints the model's parameters and the optimizer's parameters respectively.
Output:
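With the sketch above, the output looks roughly like the following; the tensor shapes follow from the layer definitions, while the exact optimizer hyperparameters listed vary across PyTorch versions:

```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])

Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
```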
2. Saving and loading models for inference

Save/load the state_dict (recommended)
Code for saving:
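A minimal sketch, where `PATH` is a placeholder for your target file path (e.g. `'model.pt'`):

```python
torch.save(model.state_dict(), PATH)
```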
Code for loading:
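A matching sketch for loading; `TheModelClass` again stands in for your own model class:

```python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```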
When saving a model for inference, you only need to save the trained model's learnable parameters. Using `torch.save()` to save the model's state dictionary makes it easier to load the model later, which is why this is the recommended approach.

A common convention is to save models with a `.pt` or `.pth` file extension.
Remember:

- Before running inference, you must call `model.eval()` to set the `dropout` and `batch normalization` layers to evaluation mode. Otherwise, the model will produce inconsistent inference results.
- `load_state_dict()` takes a dictionary object, not the path where the object was saved. That is, you must first deserialize the saved dictionary with `torch.load()` and then pass it to the method, as in the example above; you cannot call `model.load_state_dict(PATH)` directly.
Save/load the entire model
Save:
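A minimal sketch (`PATH` is again a placeholder):

```python
torch.save(model, PATH)
```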
Load:
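A matching sketch; note that the model class definition must still be in scope (or importable) when loading:

```python
# The model class must be defined somewhere accessible
model = torch.load(PATH)
model.eval()
```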
Both saving and loading the entire model use very intuitive syntax and take only a few lines of code. This approach saves the whole model using Python's `pickle` module. Its disadvantage is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved. The reason is that `pickle` does not save the model class itself; instead, it saves a path to the file containing the class. As a result, the code may break when used in other projects or after refactoring.
3. Saving and loading a general checkpoint (Checkpoint)
Sample code for saving:
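A sketch of such a checkpoint; `epoch` and `loss` are assumed to hold the current epoch number and the last training loss:

```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, PATH)
```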
Sample code for loading:
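A matching sketch for loading; `TheModelClass` and `TheOptimizerClass` are placeholders for your own classes:

```python
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
```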
When saving a general checkpoint, whether for resuming training or for inference, you need to save more than just the model's `state_dict`. The optimizer's `state_dict` is also important, since it contains parameters and buffers that are updated as the model trains. Other information worth saving includes the `epoch` at which training stopped, the latest training loss, external `torch.nn.Embedding` layers, and so on.
The saving code above shows how to save these different kinds of information: organize them in a dictionary and then call `torch.save()`. A common convention is to save such checkpoints with a `.tar` file extension.
The loading code is also shown above. First initialize the model and the optimizer, then call `torch.load()` to load the saved dictionary, and look up each item's `state_dict` and the other saved values by their keys.
After loading, follow the usual steps: call `model.eval()` for inference or `model.train()` to resume training.
4. Saving multiple models in one file
Sample code for saving:
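A sketch with two hypothetical model/optimizer pairs, `modelA`/`optimizerA` and `modelB`/`optimizerB`:

```python
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
```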
Sample code for loading:
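The matching load sketch; the class names are placeholders for your own model and optimizer classes:

```python
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```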
When you want to save a model composed of multiple `torch.nn.Modules` networks, such as a GAN, a sequence-to-sequence model, or an ensemble of models, the approach is the same as saving a general checkpoint: use a dictionary to hold each model's `state_dict` and the corresponding optimizer's `state_dict`. You can also save any other items that may help you resume training, in the same way.
The sample code for loading is shown above. As with loading a general checkpoint, you must first initialize the corresponding models and optimizers. Again, the saved file commonly uses the `.tar` extension.
5. Warmstarting a model using parameters from another model
Sample code for saving:
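A sketch: save the state dictionary of the pretrained model (here called `modelA`):

```python
torch.save(modelA.state_dict(), PATH)
```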
Sample code for loading:
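And load it into a (possibly different) model, here called `modelB`, ignoring non-matching keys:

```python
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
```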
The earlier transfer learning tutorial introduced how fine-tuning from a pretrained model can speed up training and improve accuracy.

This approach usually loads part of the pretrained model's network parameters as the initialization of the new model, which can accelerate its convergence.

The loading code is shown above. Setting the parameter `strict=False` ignores the keys of non-matching layers, because we usually do not adopt exactly the same network as the pretrained model; the output layer's parameters, for instance, are often different.

Of course, if you want to load parameters whose names do not match, you can rename the parameter keys in the loaded state dictionary; once the names match, loading will succeed.
6. Saving and loading models across devices
Save on GPU, load on CPU
Sample code for saving:
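A sketch; saving is the same as before, with `PATH` as a placeholder:

```python
torch.save(model.state_dict(), PATH)
```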
Sample code for loading:
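A sketch of loading onto the CPU via `map_location`:

```python
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
```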
When loading a model trained on the GPU onto the CPU, you must set the `map_location` parameter when calling `torch.load()` and specify the device as `torch.device('cpu')`. This remaps all tensors onto the CPU.
Save on GPU, load on GPU
Sample code for saving:
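Saving is unchanged, a minimal sketch:

```python
torch.save(model.state_dict(), PATH)
```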
Sample code for loading:
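A sketch of loading and then moving the model to the GPU:

```python
device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Every input tensor must also be moved, e.g. my_tensor = my_tensor.to(device)
```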
When training and loading the model on the GPU, after loading it with `torch.load()` you also need to call `model.to(torch.device('cuda'))` to move the model onto the GPU. In addition, every input tensor fed to the model must be on the GPU as well, i.e. you also need to call `my_tensor.to(device)`. Note that `my_tensor.to(device)` returns a new copy of the tensor on the GPU rather than modifying it in place, so remember to reassign it: `my_tensor = my_tensor.to(device)`.
Save on CPU, load on GPU
Sample code for saving:
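Again a minimal saving sketch:

```python
torch.save(model.state_dict(), PATH)
```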
Sample code for loading:
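A sketch of loading a CPU-trained model onto a GPU; `map_location` can name a specific device:

```python
device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
# Choose whichever GPU device you want, e.g. 'cuda:0'
model.load_state_dict(torch.load(PATH, map_location='cuda:0'))
model.to(device)
```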
This time the model was trained on the CPU but will be loaded for use on the GPU, so you need to specify the device via the `map_location` parameter, and then still remember to call `model.to(torch.device('cuda'))` to convert the model's parameter tensors to CUDA tensors.
Saving a torch.nn.DataParallel model
Sample code for saving:
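A sketch, assuming `model` is wrapped in `torch.nn.DataParallel`:

```python
torch.save(model.module.state_dict(), PATH)
```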
`torch.nn.DataParallel` wraps a model for parallel execution on multiple GPUs. When saving such a model, save `model.module.state_dict()`, i.e. the state dictionary of the wrapped model.

The loading code is then the same as usual: use `torch.load()`, and you are free to place the model on whichever GPU you want, as shown in the previous subsections.
