PyTorch series | Quick-start transfer learning

lc013 2021-09-15 08:32:04

 


 

Author | Sasank Chilamkurthy

Introduction

This tutorial introduces how to implement transfer learning with deep learning. For a more detailed treatment of transfer learning, see the CS231n course notes: https://cs231n.github.io/transfer-learning/

In practice, very few people train a convolutional neural network from scratch with random initialization, because it is rare to have a large enough dataset. Instead, a model pretrained on a large dataset (for example ImageNet, which has about 1.2 million images in 1000 categories) is commonly used either to initialize the network or as a fixed feature extractor.

The two main ways to use transfer learning are:

  • Fine-tuning the network: the pretrained model is used to initialize the weights instead of random initialization. This usually converges faster and reaches better accuracy.

  • Using the network as a fixed feature extractor: the weights of every layer except the final fully connected layer are frozen. The final fully connected layer is replaced so that its output size matches the number of classes in the dataset, its weights are initialized randomly, and only that layer is trained.

The imports used throughout this tutorial are as follows:

# License: BSD
# Author: Sasank Chilamkurthy
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
plt.ion() # interactive mode

Load data

Data loading uses the torchvision and torch.utils.data modules.

The goal of this tutorial is to train a binary classifier whose classes are ants and bees. The dataset contains roughly 120 training images per class and about 75 validation images per class, around 390 images in total. That is a very small dataset; a model trained on it from scratch would struggle to generalize well, so this article uses transfer learning to obtain better generalization.

The dataset and code for this article can be obtained by replying "pytorch transfer learning" to the official WeChat account.

The data-loading code is as follows:

# Data augmentation for training: random cropping and horizontal flipping, then normalization.
# The validation set is only resized, center-cropped, and normalized; no augmentation.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Folder containing the dataset
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
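
Note that datasets.ImageFolder assumes one subfolder per class inside each split folder. A minimal sketch to confirm the expected layout, assuming the dataset has been unpacked to data/hymenoptera_data as above (the class folder names ants and bees come from the downloaded dataset):

import os

data_dir = 'data/hymenoptera_data'
# Expected layout: data/hymenoptera_data/{train,val}/{ants,bees}/*.jpg
for split in ['train', 'val']:
    for cls in ['ants', 'bees']:
        folder = os.path.join(data_dir, split, cls)
        print(folder, '->', len(os.listdir(folder)), 'images')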

Visualizations

First, visualize a few training images to get a better feel for the data augmentation. The code is as follows:

# Helper to display an image
def imshow(inp, title=None):
    """Imshow for Tensor."""
    # Convert the tensor back to a NumPy array and move the channel axis last
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization to recover the original image
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get one batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from the batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

The displayed images look like this:

[Image: a grid of sample training images labeled with their classes]

Training the model

With the data loaded, the next step is to write a generic training function. It handles two things:

  • Scheduling the learning rate

  • Saving the best model

In the code below, the scheduler parameter is a learning-rate scheduler object created with torch.optim.lr_scheduler:

# Generic training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            # The model must be switched to the right mode for each phase
            if phase == 'train':
                model.train()   # Set model to training mode
            else:
                model.eval()    # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Track gradient history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backpropagate and update parameters only during training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Step the learning-rate scheduler once per epoch, after the training phase
            # (calling it before optimizer.step() is deprecated since PyTorch 1.1)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Deep copy the model if it is the best so far on the validation set
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best model weights
    model.load_state_dict(best_model_wts)
    return model

The function above trains the model. Each epoch is split into a training phase and a validation phase: the training phase performs forward computation, backpropagation, and parameter updates, while the validation phase only performs forward computation and records the loss and accuracy on the validation set.

The function also decides when to keep the model: whenever the validation accuracy exceeds the previous best accuracy, the current weights are saved as the best model.
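
The best weights are only kept in memory here. If you also want to persist them to disk, a minimal sketch follows; the checkpoint filename is arbitrary and not part of the original tutorial, and best_model_wts / model refer to the variables inside train_model:

# Save the best weights found during training (filename is an assumption)
torch.save(best_model_wts, 'best_model.pth')

# Restore them later into a model with the same architecture
model.load_state_dict(torch.load('best_model.pth'))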

Visualizing the model's predictions

The function below visualizes the model's predictions by displaying validation images together with the class the model predicts for each one:

# Visualize model predictions: display images with their predicted class,
# 6 images by default
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
    model.train(mode=was_training)

Fine-tuning the network

This part is the core of the transfer-learning approach. The previous sections were ordinary data-loading and training code; here we set up the network for fine-tuning. The code is as follows:

# Load a resnet18 model and request the pretrained weights
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Replace the output layer so its size matches the number of classes (2 here)
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Optimize all network parameters
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Learning-rate schedule: decay the LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Because the pretrained weights are requested (pretrained=True), running this step downloads the resnet18 model file the first time it is executed.
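
As a side note, recent torchvision releases (0.13 and later) deprecate the pretrained= flag in favor of a weights= argument. If you are on such a version, the equivalent call would look roughly like this (a sketch, assuming torchvision >= 0.13 is installed):

from torchvision import models

# ResNet18_Weights.DEFAULT selects the current best ImageNet weights
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)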

Training and validation

Now we actually train the network with the code below. On a CPU this takes roughly 15-25 minutes; on a GPU it is much faster, finishing in about a minute.

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Training results:

[Image: training and validation loss/accuracy log over 25 epochs]

Visualize the model's predictions:

visualize_model(model_ft)

The visualization results are as follows:

[Image: validation images with their predicted class labels]

Using the network as a feature extractor

The previous section fine-tuned the network, that is, the pretrained model was used to initialize all layer parameters. The second way to use transfer learning is as a fixed feature extractor: the weights of the pretrained network, apart from the new output layer, are frozen. The implementation is shown below. To freeze the convolutional layers, set requires_grad = False on their parameters so that no gradients are computed for them during backpropagation. For more details, see https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward

model_conv = torchvision.models.resnet18(pretrained=True)
# Freeze the weights of all pretrained layers
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters of the new output layer are optimized
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Learning-rate schedule: decay the LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
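
As a quick sanity check (not part of the original tutorial), you can confirm that only the new fully connected layer will be trained:

# List the parameters that still require gradients; only the new head should remain
trainable = [name for name, p in model_conv.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['fc.weight', 'fc.bias']

# Number of trainable parameters: 512 inputs * 2 outputs + 2 biases = 1026
print(sum(p.numel() for p in model_conv.parameters() if p.requires_grad))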

Train again:

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

Training results:

[Image: training and validation loss/accuracy log over 25 epochs]

Finally, visualize the predictions of this model:

visualize_model(model_conv)
plt.ioff()
plt.show()

 

 

[Image: validation images with their predicted class labels]

 

Please include a link to the original when reprinting. Thank you.