Author | Sasank Chilamkurthy
Brief introduction
This tutorial introduces how to implement transfer learning with deep learning. More detail on transfer learning can be found in the cs231n course notes: https://cs231n.github.io/transfer-learning/
In practice, few people train a convolutional neural network from scratch with random initialization, because it is rare to have a large enough dataset. Instead, it is common to take a model pretrained on a large dataset (such as ImageNet, with 1.2 million images across 1,000 categories) and use it either to initialize a convolutional neural network or to extract features.
The two main application scenarios of transfer learning:
- Fine-tuning the network: the pretrained model is used to initialize the network instead of random initialization. This converges faster and achieves better results.
- As a feature extractor: the weights of every layer except the final fully connected layer are frozen. The final fully connected layer is then replaced so that its output matches the number of classes in the dataset, its weights are randomly initialized, and only that layer is trained.
For this tutorial, the modules to import in the code are as follows:
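A sketch of the imports, following the official PyTorch transfer learning tutorial on which this article is based:

```python
from __future__ import print_function, division

import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

plt.ion()   # interactive mode, so figures show as they are drawn
```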
Load data
Loading the data uses two modules: torchvision and torch.utils.data.
The goal of this tutorial is to train a binary classifier whose categories are ants and bees. The dataset contains about 120 training images of ants and bees per class, plus 75 validation images per class, i.e., only about 390 images in total. Fewer than a thousand images is a very small dataset, and a model trained from scratch on it would struggle to generalize well. Therefore, this article applies transfer learning to obtain better generalization on this dataset.
To get the dataset and code for this article, reply "pytorch transfer learning" in the official account.
The code to load the data is as follows:
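A sketch following the official tutorial, which uses the hymenoptera dataset (available at https://download.pytorch.org/tutorial/hymenoptera_data.zip); the hymenoptera_data directory path and the batch size of 4 are the tutorial's defaults and may need adjusting:

```python
# Data augmentation and normalization for training; only normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Assumes the dataset is unpacked to hymenoptera_data/ with train/ and val/ subfolders
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```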
Visualizations
First, visualize a few training images to better understand the data augmentation. The code is as follows:
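A sketch following the official tutorial: an imshow helper un-normalizes a tensor and displays it, then one batch from the training loader is shown as a grid:

```python
def imshow(inp, title=None):
    """Un-normalize a tensor and display it as an image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean   # undo the normalization applied in data_transforms
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # give the plot window time to update

# Get one batch of training data and display it as a grid
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
```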
The displayed images are as follows:

Training the model
After loading the data, it is time to train the model. Two things are handled here:
- Scheduling the learning rate
- Saving the best model
In the following code, the scheduler parameter is a learning-rate policy object initialized with torch.optim.lr_scheduler:
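A sketch of the training function, following the official tutorial; it relies on the dataloaders, dataset_sizes, and device objects defined above. Note that in recent PyTorch versions scheduler.step() is called once per epoch, after the optimizer updates:

```python
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # training mode (dropout active, batchnorm updating)
            else:
                model.eval()    # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Forward pass; gradients are tracked only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass and parameter update only in training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()   # advance the learning-rate schedule once per epoch

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Save the weights whenever validation accuracy beats the best so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best weights before returning the model
    model.load_state_dict(best_model_wts)
    return model
```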
The function above implements model training. Each epoch is split into a training phase and a validation phase. The training phase requires a forward pass plus backpropagation to update the network parameters, while the validation phase only needs the forward pass, after which the loss and validation accuracy are recorded.
In addition, a condition for saving the model is set: whenever the validation accuracy is higher than the best accuracy so far, the model is saved.
Visualizing model predictions
The following defines a function for visualizing the model's predictions, which displays images together with the class the model predicts for each:
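A sketch following the official tutorial; it reuses the imshow helper and the validation dataloader defined above:

```python
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
```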
Fine-tuning the network
This part is the core of this transfer learning tutorial. The preceding sections were the usual code for loading data and defining the training procedure; here the network is fine-tuned. The code is as follows:
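A sketch following the official tutorial: a pretrained ResNet-18 is loaded, its final fully connected layer is replaced with a 2-class output, and all parameters are optimized with SGD under a step learning-rate schedule:

```python
# Load a ResNet-18 pretrained on ImageNet
model_ft = models.resnet18(pretrained=True)

# Replace the final fully connected layer with a new 2-class output layer
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()

# All parameters are optimized (this is what makes it fine-tuning)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
```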
Because this step loads a pretrained model, running it will download the pretrained resnet18 network model file.
Training and validation
Next, formally train the network model with the code below. On a CPU the process takes about 15-25 minutes; on a GPU it is much faster, finishing in roughly a minute.
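The call itself is a single line (a sketch, using the objects defined above):

```python
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
```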
Training results:

Visualize the model's predictions:
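A sketch of the call, reusing the visualize_model function defined earlier:

```python
visualize_model(model_ft)
```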
The visualization results are as follows:

As a feature extractor
The previous section used the pretrained model to initialize the network layers' parameters and then fine-tuned them. Next comes the second use of transfer learning: as a feature extractor, where part of the pretrained model's layer weights are frozen. In the implementation below, the convolutional layers' parameters are fixed by setting requires_grad == False, so their gradients are not computed during backpropagation. For more detail, see https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward
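A sketch following the official tutorial: every pretrained parameter is frozen, then a new final layer (trainable by default) is attached, and only its parameters are passed to the optimizer:

```python
model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False   # freeze: no gradients in backprop

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()

# Only the parameters of the final layer are optimized
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
```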
Train again:
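A sketch, mirroring the fine-tuning call above:

```python
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
```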
Training results:

Visualizing the network's predictions:
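A sketch of the final calls, following the official tutorial:

```python
visualize_model(model_conv)

plt.ioff()   # turn interactive mode back off
plt.show()
```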
