PyTorch Train/Test Using a Pretrained Model and Artificial Data
This is a demonstration of using PyTorch to fine-tune a pre-trained model on an artificial dataset (zeros and ones). This code can be a good place to start from when creating a new neural network in PyTorch, as it demonstrates many of the basic concepts in a functional context.
This model trains very fast due to the extremely simple nature of the task and will run on either CPU or GPU, auto-selecting the GPU if available.
Included are the basic elements:
- Selecting GPU/CPU
- Building a PyTorch dataset
- Setting hyper parameters
- Importing a pre-trained model architecture
- Setting optimizer and criterion
- Creating a PyTorch DataLoader
- Creating Train/Test functions
- Training then Testing using a loop with outputs
This code was created and tested using Environment E037
import torch, torchvision, torch.optim as optim
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class one_zero_dataset(torch.utils.data.Dataset):
    """Synthetic dataset of ImageNet-shaped tensors alternating zeros and ones.

    Even indices yield an all-zeros (3, 224, 224) float tensor labeled 0.0;
    odd indices yield an all-ones tensor labeled 1.0.
    """

    def __len__(self):
        # Fixed dataset size of 100 samples
        return 100

    def __getitem__(self, index):
        # The label doubles as the fill value: 0.0 for even, 1.0 for odd indices
        gnd_trth = float(index % 2)
        data = torch.full((3, 224, 224), gnd_trth, dtype=torch.float32)
        return data, gnd_trth
def train_single_epoch(device, dataloader, model, criterion, optimizer):
    """Train the model for a single epoch.

    Column 1 of the model output is treated as the positive-class score and
    compared (after rounding) against the ground truth to compute accuracy.

    Args:
        device: torch.device the batches and model run on.
        dataloader: iterable yielding (batch, gnd_truth) pairs.
        model: network producing at least a 2-column output per sample.
        criterion: loss callable taking (predictions, float targets).
        optimizer: optimizer constructed over model.parameters().

    Returns:
        (model, mean loss per batch, mean accuracy per batch) — loss and
        accuracy are Python floats.
    """
    # zero accumulating values for loss and acc
    running_acc = 0.0
    running_loss = 0.0
    # set model to train state (enables dropout / batchnorm updates)
    model.train()
    # Iterate over data and train
    for batch, gnd_truth in dataloader:
        # Send to gpu or cpu
        batch, gnd_truth = batch.to(device), gnd_truth.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward
        preds = model(batch)
        # get loss on the positive-class column
        loss = criterion(preds[:, 1], gnd_truth.float())
        # backward
        loss.backward()
        # optimize
        optimizer.step()
        # update running loss and acc
        # FIX: use .item() so a Python float accumulates instead of a device
        # tensor (the original returned a tensor, possibly on the GPU); the
        # equality-mean is equivalent to the original sum/shape computation
        running_acc += (torch.round(preds[:, 1]) == gnd_truth).float().mean().item()
        running_loss += loss.item()
    return model, running_loss / len(dataloader), running_acc / len(dataloader)
def test_single_epoch(device, dataloader, model, criterion, optimizer):
#zero accumulating values for los and acc
running_acc = 0.0
running_loss = 0.0
#set model to eval state
with torch.no_grad():
# Iterate over data and test
for batch, gnd_truth in dataloader:
#Send to gpu or cpu
batch, gnd_truth = batch.to(device), gnd_truth.to(device)
# forward
preds = model(batch)
#get loss
loss = criterion(preds[:,1], gnd_truth.float())
# update running loss and acc
running_acc += (torch.round(preds[:,1]).unsqueeze(1) == gnd_truth.unsqueeze(1)).sum()/ gnd_truth.unsqueeze(1).shape[0]
running_loss += loss.item()
return running_loss/len(dataloader), running_acc/len(dataloader)
# ---- Script entry point -------------------------------------------------
# Guarded so DataLoader worker processes (num_workers > 0) can safely
# re-import this module on platforms that use the 'spawn' start method.
if __name__ == '__main__':
    device = get_device()
    # Hyper parameters
    learning_rate = 0.0001
    batch_size = 10
    epochs = 10
    # Pre-trained ResNet-18 with a fresh 2-class head; Softmax makes the head
    # emit class probabilities. (NOTE(review): newer torchvision deprecates
    # `pretrained=` in favor of `weights=` — confirm against the installed
    # version before upgrading.)
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(in_features=512, out_features=2, bias=True),
        torch.nn.Softmax(dim=1),
    )
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # FIX: the head already applies Softmax, so the criterion must consume
    # probabilities. BCEWithLogitsLoss applies a second sigmoid on top of the
    # softmax output (squashing values into ~[0.5, 0.73] and distorting the
    # gradients); BCELoss is the correct loss for probability inputs.
    criterion = torch.nn.BCELoss()
    train_data_loader = torch.utils.data.DataLoader(
        one_zero_dataset(), batch_size=batch_size, shuffle=True, num_workers=12)
    test_data_loader = torch.utils.data.DataLoader(
        one_zero_dataset(), batch_size=batch_size, shuffle=True, num_workers=2)
    print('------------------------------------------------------------------------')
    print('starting training for {} epochs '.format(epochs))
    # loop over the dataset multiple times training and then testing
    for epoch in range(1, epochs + 1):
        model, train_loss, train_acc = train_single_epoch(
            device, train_data_loader, model, criterion, optimizer)
        test_loss, test_acc = test_single_epoch(
            device, test_data_loader, model, criterion, optimizer)
        if (epoch % 2) == 0:  # Print every 2 epochs
            print('finished training for epoch {}'.format(epoch))
            print('train_loss: {:.4f}'.format(train_loss))
            print('test_loss: {:.4f}'.format(test_loss))
            print('train_acc: {:.4f}'.format(train_acc))
            print('test_acc: {:.4f}'.format(test_acc))
            print('\n----------------------------------\n')