Introduction to Fine-tuning a PyTorch Pretrained Model using an Arbitrary Number of Classes
In this post we are going to take a PyTorch model that has been pre-trained on ImageNet (with 1000 classes) and fine-tune it on a different dataset (CIFAR2) with 2 classes.
This post is aimed at those who are familiar with Jupyter, Python and deep learning concepts in general, e.g. TensorFlow or fastai users. If you are not already familiar with deep learning, you may find Practical Deep Learning for Coders by Jeremy Howard and Sylvain Gugger to be an excellent introduction to both the fastai library and deep learning in general.
Although it is not included in this example, using a tracking system such as the open source MLflow to log results during training and to store the resulting model would be best practice, and it provides many benefits over the print-out methods demonstrated here.
The training and dataset creation have been deliberately kept streamlined in order to show the concepts involved. In a real situation additional work would be needed to ensure reliable and robust functioning.
This code was tested on environment E037 using JupyterLab.
import torch
import torchvision as tv
import numpy as np
from PIL import Image
from IPython.display import Image as IpyImg
import matplotlib.pyplot as plt
def get_device():
    """Set as GPU if available, else set CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
device = get_device()
device
Model
The model we will use for this task is a ResNet18, chosen for its relatively small size and good performance. This could be changed later as part of optimisation, either by selecting a smaller model for cheaper training and inference or a larger model to improve accuracy.
model_name = 'resnet18'
learning_rate = 0.001
Loss
CrossEntropyLoss will provide what we need for this binary classification task. The loss function can be changed to suit the situation, but that is out of scope for this blog post.
loss_name = 'CrossEntropyLoss'
batch_size = 4
num_class = 2
Set Epochs for Training the Model
This is a relatively simple task for the pretrained model. We have selected 2 epochs for training the new output layers followed by 15 epochs for fine-tuning the whole model, which should provide plenty of time to fine-tune. These values would need to be tuned, or early stopping might be useful, but that is outside the scope of this post.
last_layer_epochs = 2
all_layer_epochs = 15
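Although early stopping is not used in this post, the idea can be sketched in a few lines. The EarlyStopping class, its patience and min_delta parameters, and the loss values below are all hypothetical and only serve to illustrate the mechanism: training stops once the validation loss has failed to improve for a set number of epochs.

```python
class EarlyStopping:
    """Hypothetical helper: signals a stop when the validation loss stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, purely for illustration
stopper = EarlyStopping(patience=2)
made_up_losses = [0.70, 0.55, 0.56, 0.57, 0.40]
stopped_at = None
for epoch, loss in enumerate(made_up_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print("stopped at epoch", stopped_at)
```

In a real loop, stopper.step(val_loss) would be called once per epoch with the validation loss returned by the test function.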
Load Pretrained Model
See torchvision.models for details on the available models and the PyTorch model zoo.
This method will work with the ResNet architecture because its last layer is named 'fc'. The code could be adapted to suit other model architectures by referencing their last named layer.
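As a rough illustration of referencing the last named layer, the sketch below looks up the last registered child module by name and swaps it out. TinyNet is a hypothetical stand-in model, and the assumption that the last registered child is a single Linear output layer does not hold for every architecture (some use a Sequential classifier, for example), so the printed name should always be checked first.

```python
import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for a pretrained model with a differently named head."""

    def __init__(self):
        super().__init__()
        self.features = torch.nn.Linear(8, 4)
        self.classifier = torch.nn.Linear(4, 10)

    def forward(self, x):
        return self.classifier(torch.relu(self.features(x)))


net = TinyNet()

# named_children() yields (name, module) pairs in registration order
last_name, last_layer = list(net.named_children())[-1]
print(last_name, last_layer.in_features)

# Replace the layer by name with a new 2-class output layer
setattr(net, last_name, torch.nn.Linear(last_layer.in_features, 2))
out = net(torch.randn(3, 8))
print(out.shape)
```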
First we load the pretrained model as-is and then modify it to suit our use case.
model = tv.models.__dict__[model_name](pretrained=True)
model
Create New Output Layer
We replace the single Linear output layer with a Sequential (sub)model that includes 2 linear layers separated by a ReLU and ending in a softmax function.
The last layer of this model is called 'fc', so we reference this layer by name to replace it with our new output layers.
More information is available in the PyTorch documentation.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features=512, out_features=512, bias=True),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=512, out_features=2, bias=True),
    torch.nn.Softmax(dim=1)
)
model
model = model.to(device)
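A quick sanity check of the new output block can be done without the full model. The sketch below rebuilds the same head on its own and feeds it random 512-dimensional vectors (fake_features is a made-up stand-in for the ResNet18 features), confirming that each row of the output has 2 values that sum to 1.

```python
import torch

# Same structure as the new output block, built standalone for inspection
head = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 2),
    torch.nn.Softmax(dim=1),
)

fake_features = torch.randn(4, 512)  # a batch of 4 random feature vectors
probs = head(fake_features)
print(probs.shape)       # one row per sample, one column per class
print(probs.sum(dim=1))  # softmax rows sum to 1
```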
Set Up the Loss Function
A criterion, or loss function object, is created using a PyTorch loss function and the loss name we defined earlier.
loss_fn = torch.nn.__dict__[loss_name]()
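One detail worth knowing is that torch.nn.CrossEntropyLoss applies log-softmax internally and is documented as expecting raw, unnormalised scores (logits). The sketch below, using made-up values, checks that it matches a manual log-softmax followed by the negative log-likelihood loss. Because the output block in this post already ends in a Softmax, the loss here receives probabilities rather than logits; training still runs, but this is something to be aware of when adapting the code.

```python
import torch
import torch.nn.functional as F

# Made-up raw scores for 2 samples and 2 classes
logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])
targets = torch.tensor([0, 1])

ce = torch.nn.CrossEntropyLoss()(logits, targets)

# The same quantity computed in two explicit steps
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(float(ce), float(manual))
```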
Set Up the Optimiser
Using the torch.optim package we create an optimizer object with the learning rate we defined earlier.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
data_transforms = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Get a Dataset
Torchvision has a range of image-based datasets available to use. We will be creating CIFAR2 from CIFAR10. PyTorch has built-in methods to download and access the data using the torchvision.datasets.CIFAR10 class, and we will use this as the initial source for our data.
CIFAR10 has 10 balanced classes, which are:
- Airplane
- Automobile
- Bird
- Cat
- Deer
- Dog
- Frog
- Horse
- Ship
- Truck
In this example we will select the first two classes (planes and cars) and create a binary classification task.
To load the data and feed it to the PyTorch data loader, we extend the torch.utils.data.Dataset class and create the __len__ and __getitem__ methods to return each element of the dataset.
class cifar_bin_dataset(torch.utils.data.Dataset):
    def __init__(self, train_set, data_transforms):
        'Initialization'
        self.data_transform = data_transforms
        # Import the CIFAR10 data; we will be accessing this to get the data we need
        cifar10_dataset_obj = tv.datasets.CIFAR10(root='.', train=train_set, download=True)
        self.image_array = cifar10_dataset_obj.data
        # Get the number of examples of each class (CIFAR10 has 10 balanced classes)
        self.number_class_targets = int(len(cifar10_dataset_obj.targets)/10)
        plane_idx = [i for i, value in enumerate(cifar10_dataset_obj.targets) if value == 0]
        car_idx = [i for i, value in enumerate(cifar10_dataset_obj.targets) if value == 1]
        self.idx_list = plane_idx + car_idx
        # Labels line up with idx_list: planes (0) first, then cars (1)
        self.target_tensor = np.concatenate(
            (np.zeros(self.number_class_targets),
             np.ones(self.number_class_targets)),
            axis=0
        ).astype(np.float32)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.idx_list)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Turn the np array into a PIL image before applying the transforms
        img = Image.fromarray(self.image_array[self.idx_list[index]])
        # Use the transforms passed to the constructor rather than the global variable
        return self.data_transform(img), int(self.target_tensor[index])
train_dataset = cifar_bin_dataset(True, data_transforms)
val_dataset = cifar_bin_dataset(False, data_transforms)
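The same Dataset pattern can be seen in its smallest form in the sketch below. SquaresDataset is a hypothetical toy class, not part of this post's pipeline; it only shows that __len__ and __getitem__ are all a DataLoader needs from a dataset.

```python
import torch


class SquaresDataset(torch.utils.data.Dataset):
    """Hypothetical toy dataset: item i is the pair (tensor([i]), i squared)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Total number of samples
        return self.size

    def __getitem__(self, index):
        # Return one (input, label) pair
        x = torch.tensor([float(index)])
        return x, int(index ** 2)


toy = SquaresDataset(5)
print(len(toy), toy[3])
```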
Exploring the data is a vital first step, so here we create a function to print out an image and its associated label.
The dataset can be accessed directly by index using the __getitem__ method we created in the cifar_bin_dataset class.
We use torchvision.transforms to change the PyTorch tensor into a PIL object that can then be displayed using IPython.display.
print('The train_dataset length is', len(train_dataset), 'elements')
print('The val_dataset length is', len(val_dataset), 'elements')
def display_image(dataset, idx):
    display_transform = tv.transforms.Compose([tv.transforms.ToPILImage(), tv.transforms.Resize((100, 100))])
    class_lookup = {0: 'plane', 1: 'car'}
    print(class_lookup[dataset[idx][1]])
    display(display_transform(dataset[idx][0]))
    print('-----------\n')
for idx in range(0, len(train_dataset), 1000):
    display_image(train_dataset, idx)
Create Data Loader
In PyTorch a DataLoader takes a torch.utils.data.Dataset object as an input along with hyperparameters such as batch size.
The number of workers (num_workers) should be tuned to the hardware this is run on; generally, good settings are slightly less than the number of threads available.
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
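Rather than hard-coding the number of workers, it can be derived from the machine the code runs on. The suggest_num_workers helper and its reserve parameter below are hypothetical, a rough heuristic following the rule of thumb above rather than a tuned recommendation; the toy TensorDataset is only there to make the sketch self-contained.

```python
import os
import torch


def suggest_num_workers(reserve=2):
    """Hypothetical heuristic: leave `reserve` CPU threads free for other work."""
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    return max(0, available - reserve)


# Toy data: 16 samples with 3 features each and integer labels
toy_data = torch.utils.data.TensorDataset(torch.randn(16, 3), torch.arange(16))

workers = suggest_num_workers()
loader = torch.utils.data.DataLoader(
    toy_data, batch_size=4, shuffle=True, num_workers=workers
)
print(workers, len(loader))
```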
Create Test Function
This function receives a dataloader, model and loss function. It loops through the dataloader calculating loss and accuracy, then returns the averages of these for the whole epoch.
Wrapping the loop in torch.no_grad() means the gradients of the model are not recorded.
running_acc is updated with a calculation that rounds the second column of the prediction (the predicted probability of class 1) and compares that to the ground truth, returning the fraction that were correct.
def test_single_epoch(device, dataloader, model, loss_fn):
    # zero accumulating values for loss and acc
    running_acc = 0.0
    running_loss = 0.0
    # set model to eval state (affects layers such as batch norm)
    model.eval()
    with torch.no_grad():
        # Iterate over data and test
        for batch, gnd_truth in dataloader:
            # Send to gpu or cpu
            batch, gnd_truth = batch.to(device), gnd_truth.to(device)
            # forward
            preds = model(batch)
            # get loss
            loss = loss_fn(preds, gnd_truth)
            # update running loss and acc
            running_acc += (torch.round(preds[:, 1]) == gnd_truth).sum() / gnd_truth.shape[0]
            running_loss += loss.item()
    return float(running_loss/len(dataloader)), float(running_acc/len(dataloader))
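The accuracy calculation can be followed by hand on a small batch. The probabilities and labels below are made up for illustration; the second column is the predicted probability of class 1, so rounding it gives the predicted class for a binary task.

```python
import torch

# Made-up softmax outputs for a batch of 4 and their true labels
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
gnd_truth = torch.tensor([0, 1, 1, 1])

# Round the class-1 probability to get the predicted class (0 or 1)
predicted_class = torch.round(preds[:, 1])

# Fraction of the batch predicted correctly: 3 of the 4 match
batch_acc = (predicted_class == gnd_truth).sum() / gnd_truth.shape[0]
print(float(batch_acc))  # 0.75
```

For more than 2 classes the rounding trick no longer applies, and something like preds.argmax(dim=1) would be needed to pick the predicted class instead.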
Create Train Function
Training is very similar to the test function above, with some key differences. We call model.train() at the start to put the model into training mode, which affects layers such as batch normalisation. We also need to clear the gradients at the start of each batch with optimizer.zero_grad(); then, after calculating the loss, we call loss.backward() to calculate the gradients for all of the parameters (tensors) in the model.
Finally we call optimizer.step() so that the optimizer iterates over all parameters (tensors) it is set to update, using the stored gradients from the previous step.
def train_single_epoch(device, dataloader, model, loss_fn, optimizer):
    # zero accumulating values for loss and acc
    running_acc = 0.0
    running_loss = 0.0
    # set model to train state
    model.train()
    # Iterate over data and train
    for batch, gnd_truth in dataloader:
        # Send to gpu or cpu
        batch, gnd_truth = batch.to(device), gnd_truth.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward
        preds = model(batch)
        # get loss
        loss = loss_fn(preds, gnd_truth)
        # backward
        loss.backward()
        # optimize
        optimizer.step()
        # update running loss and acc
        running_acc += (torch.round(preds[:, 1]) == gnd_truth).sum() / gnd_truth.shape[0]
        running_loss += loss.item()
    return model, float(running_loss/len(dataloader)), float(running_acc/len(dataloader))
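To see the zero_grad, backward and step sequence in isolation, the sketch below runs a single update on a tiny linear model with random data. The toy_model, toy_loss_fn and toy_optimizer names and the learning rate are illustrative assumptions; the point is only that gradients appear after backward() and the weights change after step().

```python
import torch

torch.manual_seed(0)

# Tiny stand-in model and random data, for illustration only
toy_model = torch.nn.Linear(4, 2)
toy_loss_fn = torch.nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

batch = torch.randn(8, 4)
gnd_truth = torch.randint(0, 2, (8,))

weights_before = toy_model.weight.detach().clone()

toy_optimizer.zero_grad()                        # clear any stored gradients
loss = toy_loss_fn(toy_model(batch), gnd_truth)  # forward pass and loss
loss.backward()                                  # populate .grad on each parameter
grad_norm = toy_model.weight.grad.norm().item()
toy_optimizer.step()                             # apply the update using stored gradients

weights_changed = not torch.equal(weights_before, toy_model.weight.detach())
print(grad_norm > 0, weights_changed)
```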
val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
print('val_loss: {:.4f}'.format(val_loss))
print('val_acc: {:.4f}'.format(val_acc))
Set Trainable Layers
In order to be able to first train the new layers, then fine-tune the whole model, we set the layers we are targeting to record gradients using requires_grad = True and the layers we want to leave untouched to requires_grad = False. It also provides an opportunity to view the model in a different way.
for named_param in model.named_parameters():
    if 'fc' in named_param[0]:
        print('Will train', named_param[0])
        named_param[1].requires_grad = True
    else:
        print('Won\'t train', named_param[0])
        named_param[1].requires_grad = False
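A simple way to confirm the freeze worked is to count trainable and frozen parameters. The sketch below does this on a hypothetical two-layer model whose layers are named 'backbone' and 'fc' to mirror the situation above; the same two sums could be run against the real model.

```python
import torch

# Hypothetical stand-in: a frozen 'backbone' followed by a trainable 'fc' head
toy_model = torch.nn.Sequential()
toy_model.add_module("backbone", torch.nn.Linear(8, 4))
toy_model.add_module("fc", torch.nn.Linear(4, 2))

# Only parameters whose name starts with 'fc' will record gradients
for name, param in toy_model.named_parameters():
    param.requires_grad = name.startswith("fc")

trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in toy_model.parameters() if not p.requires_grad)
print(trainable, frozen)  # 10 trainable (4*2 + 2), 36 frozen (8*4 + 4)
```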
Training and Validating
The train loop runs the train then validate functions and collects the results. This example shows a fairly simple loop that collects the returns from the functions and stores them in arrays. The metrics are printed out every epoch while training the new layers and every third epoch while fine-tuning the whole model.
This is the part where we can define actions taken between epochs, including saving checkpoints of the model or using MLflow to log metrics from the epoch.
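As one example of a between-epoch action, a checkpoint can be written with torch.save and restored with load_state_dict. The save_checkpoint helper, the file naming scheme and the temporary directory below are hypothetical choices for this sketch, and a tiny linear layer stands in for the real model.

```python
import os
import tempfile
import torch

toy_model = torch.nn.Linear(4, 2)    # stand-in for the real model
checkpoint_dir = tempfile.mkdtemp()  # hypothetical location for checkpoints


def save_checkpoint(model, epoch, directory):
    """Hypothetical helper: save the model weights for a given epoch."""
    path = os.path.join(directory, "checkpoint_epoch_{}.pt".format(epoch))
    torch.save(model.state_dict(), path)
    return path


path = save_checkpoint(toy_model, epoch=1, directory=checkpoint_dir)

# Restoring requires a model with the same architecture
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
print(torch.equal(restored.weight, toy_model.weight))
```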
As we aren't using a tool to track the course of the training, we will first set up some arrays to track the model performance and metrics over the course of the training.
train_acc_hist = []
train_loss_hist = []
val_acc_hist = []
val_loss_hist = []
train_loss, train_acc = test_single_epoch(device, train_data_loader, model, loss_fn)
train_acc_hist.append(train_acc)
train_loss_hist.append(train_loss)
val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
val_acc_hist.append(val_acc) #append results to the array we are using for tracking
val_loss_hist.append(val_loss) #append results to the array we are using for tracking
print('Initial Model State')
print('train_loss: {:.4f}'.format(train_loss))
print('val_loss: {:.4f}'.format(val_loss))
print('train_acc: {:.4f}'.format(train_acc))
print('val_acc: {:.4f}'.format(val_acc))
print('------------------------------------------------------------------------')
print('starting training for {} epochs '.format(last_layer_epochs))
print('------------------------------------------------------------------------')
for epoch in range(1, last_layer_epochs + 1):  # loop over the dataset multiple times
    model, train_loss, train_acc = train_single_epoch(device, train_data_loader, model, loss_fn, optimizer)
    train_acc_hist.append(train_acc)  # append results to the array we are using for tracking
    train_loss_hist.append(train_loss)  # append results to the array we are using for tracking
    val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
    val_acc_hist.append(val_acc)  # append results to the array we are using for tracking
    val_loss_hist.append(val_loss)  # append results to the array we are using for tracking
    # Print out every epoch
    print('finished training for epoch {}'.format(epoch))
    print('train_loss: {:.4f}'.format(train_loss))
    print('val_loss: {:.4f}'.format(val_loss))
    print('train_acc: {:.4f}'.format(train_acc))
    print('val_acc: {:.4f}'.format(val_acc))
    print('------------------------------------------------------------------------')
train_loss, train_acc = test_single_epoch(device, train_data_loader, model, loss_fn)
train_acc_hist.append(train_acc)
train_loss_hist.append(train_loss)
val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
val_acc_hist.append(val_acc)
val_loss_hist.append(val_loss)
print('Current Model State')
print('train_loss: {:.4f}'.format(train_loss))
print('Val dataset loss: {:.4f}'.format(val_loss))
print('train_acc: {:.4f}'.format(train_acc))
print('Val dataset acc: {:.4f}'.format(val_acc))
for named_param in model.named_parameters():
    named_param[1].requires_grad = True
print('------------------------------------------------------------------------')
print('starting training for {} epochs '.format(all_layer_epochs))
print('------------------------------------------------------------------------')
for epoch in range(1, all_layer_epochs + 1):  # loop over the dataset multiple times
    model, train_loss, train_acc = train_single_epoch(device, train_data_loader, model, loss_fn, optimizer)
    train_acc_hist.append(train_acc)  # append results to the array we are using for tracking
    train_loss_hist.append(train_loss)  # append results to the array we are using for tracking
    val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
    val_acc_hist.append(val_acc)  # append results to the array we are using for tracking
    val_loss_hist.append(val_loss)  # append results to the array we are using for tracking
    if (epoch % 3) == 0:  # Print out every third epoch
        print('finished training for epoch {}'.format(epoch))
        print('train_loss: {:.4f}'.format(train_loss))
        print('val_loss: {:.4f}'.format(val_loss))
        print('train_acc: {:.4f}'.format(train_acc))
        print('val_acc: {:.4f}'.format(val_acc))
        print('------------------------------------------------------------------------\n\n')
train_loss, train_acc = test_single_epoch(device, train_data_loader, model, loss_fn)
train_acc_hist.append(train_acc)
train_loss_hist.append(train_loss)
val_loss, val_acc = test_single_epoch(device, val_data_loader, model, loss_fn)
val_acc_hist.append(val_acc)
val_loss_hist.append(val_loss)
print('Final Model State')
print('Train dataset loss: {:.4f}'.format(train_loss))
print('Val dataset loss: {:.4f}'.format(val_loss))
print('Train dataset acc: {:.4f}'.format(train_acc))
print('Val dataset acc: {:.4f}'.format(val_acc))
train_loss_plt, = plt.plot(train_loss_hist)
val_loss_plt, = plt.plot(val_loss_hist)
plt.title("Loss Vs Epoch")
plt.legend([train_loss_plt, val_loss_plt], ['Train Loss History', 'Val Loss History'])
plt.show()  # finish this figure so the accuracy plot is drawn separately
train_acc_hist_plt, = plt.plot(train_acc_hist)
val_acc_hist_plt, = plt.plot(val_acc_hist)
plt.title("Accuracy Vs Epoch")
plt.legend([train_acc_hist_plt, val_acc_hist_plt], ['Train Acc History', 'Val Acc History'])
plt.show()
We can see from these charts that the accuracy improves from about random at the start (which is to be expected, as the last layers were randomly initialised) to about 78% after training the new layers, then to over 90% after fine-tuning the whole model.
This model could be further optimised, specifically things like the batch size, learning rate, model selection and design of the output layers to suit specific tasks.