Pet Breed Classification
An end-to-end workflow for image classification on the Oxford IIIT pet dataset using PyTorch, PyTorch Lightning, Torchmetrics and Tensorboard. This notebook trains state of the art image classification models on the Oxford IIIT pet dataset, and shows how to use torchmetrics to measure their quality.
In this example, we'll build an image classifier to detect the 37 different dog and cat breeds in the Oxford-IIIT Pet Dataset. This is a well studied example, so our main purpose here will be to illustrate its solution using Pytorch Lightning and Torchmetrics.
Torchmetrics is a new library that has many metrics for classification in Pytorch.
- It allows for easy computation over batches.
- Rigorously tested.
- A standardized interface to increase reproduciblity.
- And much more...
We'll install it below with Pytorch Lightning.
!pip install pytorch-lightning
!pip install lightning-bolts
!pip install seaborn
import os
import torch
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split
seed_everything(7)
Let's quickly load and visualize the dataset. A more thorough inspection can be done using Voxel 51.
vis_dataset = torchvision.datasets.OxfordIIITPet(root="./data",split="trainval",download=True)
for i in range(5):
image, label = vis_dataset[random.choice(range(len(vis_dataset)))]
fig = plt.figure()
plt.imshow(image)
plt.title(f"Image of a {vis_dataset.classes[label]}")
plt.axis("off")
Now let's make transforms for the training, validation, and test set. We'll use torchvision
s RandAugment
which is an automated data augmentation strategy that improves the classifier accuracy by a few percentage points.
Also since we'll be using a ResNet model pretrained on ImageNet, we'll also use ImageNet normalization.
train_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.RandAugment(),
transforms.ToTensor(),
imagenet_normalization()
])
val_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
imagenet_normalization()
])
Most of the remaining steps are fairly standard for obtaining a fixed train/val/test dataset. Note that the split is fixed since we used seed_everything
earlier.
train_dataset = torchvision.datasets.OxfordIIITPet(root="./data", split="trainval", transform=train_transform, download=False)
test_dataset = torchvision.datasets.OxfordIIITPet(root="./data", split="test",transform=val_transform, download=False)
classes = train_dataset.classes
num_classes = len(classes)
BATCH_SIZE = 32
split_len = int(0.9*len(train_dataset))
train_dataset, val_dataset = random_split(train_dataset, [split_len, len(train_dataset)-split_len], generator=torch.Generator().manual_seed(7))
print('Length of train dataset: ', len(train_dataset))
print('Length of validation dataset: ', len(val_dataset))
print('Length of test dataset: ', len(test_dataset))
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
test_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
Let's take a look at the class statistics in the train dataset.
class_counts = np.array([0]*num_classes)
for batch in train_dataloader:
_, labels = batch
for label in labels:
class_counts[label] += 1
df_classes = pd.DataFrame({"classes" : classes, "class_counts" : class_counts})
As we'll see the classes are balanced, which simplifies the treatment of this problem.
df_classes.plot(kind="barh", x="classes", y="class_counts",figsize=(8,12))
plt.show()
Now we'll write the Pytorch Lightning module to create and train the model.
To train our model, we'll freeze pretrained models from torchvision and add a linear head for fine tuning. For illustrative purposes, we'll use a ResNet34 model, but you can also try ResNext50, or RegNetX to obtain slightly higher accuracy. The model will be set as the model
attribute in the LightningModule class.
Since the LightningModule inherits from the Pytorch Module, we also have to set the forward
method of the neural network. The backward
method is automatically generated.
The training is relatively simple, using an Adam optimizer, which is set in the configure_optimizers
method.
The training, validation, and test steps for a batch of data are configured via the training_step
, validation_step
, and test_step
respectively. We'll use a single custom method called evaluate
to handle the validation and test steps. See the Lightning documentation for more details.
Finally, we'll use Accuracy and F1Score with torchmetrics
to measure the performance of our model.
- Since we're using the object oriented API, which preserves state, we have to create separate objects for the train, val and test sets.
- We wrap
Accuracy
andF1Score
in theMetricCollection
, which will automatically share common computations between the two metrics, and simplify the code. - For multiclass classifiers, the average option passed to the metrics (micro or macro) affects the calculation. The choices below were chosen to be consistent with scikit-learn, and work well with a balanaced dataset. I'll delve deeper into these nuances in following posts.
import torch.nn as nn
from torchmetrics import MetricCollection, Accuracy, F1Score
from torchvision.models import resnet34, ResNet34_Weights
from torchvision.models import resnext50_32x4d, ResNeXt50_32X4D_Weights
from torchvision.models import regnet_x_3_2gf, RegNet_X_3_2GF_Weights
class PetModel(LightningModule):
def __init__(self, arch, learning_rate, num_classes):
super().__init__()
self.save_hyperparameters()
self.example_input_array = torch.zeros((BATCH_SIZE, 3, 224,224))
# Setup the model.
self.create_model(arch)
# Setup the losses.
self.loss = nn.CrossEntropyLoss()
# Setup the metrics.
self.train_metrics = MetricCollection({"train_acc" : Accuracy(num_classes=num_classes, average="micro"),
"train_f1" : F1Score(num_classes=num_classes, average="macro")})
self.val_metrics = MetricCollection({"val_acc" : Accuracy(num_classes=num_classes, average="micro"),
"val_f1" : F1Score(num_classes=num_classes, average="macro")})
self.test_metrics = MetricCollection({"test_acc" : Accuracy(num_classes=num_classes, average="micro"),
"test_f1" : F1Score(num_classes=num_classes, average="macro")})
def create_model(self, arch):
"""
Setup a model for fine tuning.
"""
in_dimension = 512
if arch == "resnext50_32x4d":
self.model = resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.DEFAULT)
in_dimension = 4*512
elif arch == "regnet_x_3_2gf":
self.model = regnet_x_3_2gf(weights=RegNet_X_3_2GF_Weights.DEFAULT)
in_dimension = 1008
else:
self.model = resnet34(weights=ResNet34_Weights.DEFAULT)
in_dimension = 512
for param in self.model.parameters():
param.requires_grad = False
self.model.fc = nn.Linear(in_dimension, num_classes)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
images, labels = batch
logits = self(images)
preds = torch.argmax(logits, dim=1)
loss = self.loss(logits, labels)
self.train_metrics(preds, labels)
self.log("train_acc", self.train_metrics["train_acc"], prog_bar=True)
self.log("train_f1", self.train_metrics["train_f1"], prog_bar=True)
self.log("train_loss", loss, prog_bar=True)
return loss
def evaluate(self, batch, stage=None):
images, labels = batch
logits = self(images)
preds = torch.argmax(logits, dim=1)
loss = nn.CrossEntropyLoss()(logits, labels)
if stage == "val":
self.val_metrics(preds,labels)
self.log("val_acc", self.val_metrics["val_acc"], prog_bar=True)
self.log("val_f1", self.val_metrics["val_f1"], prog_bar=True)
self.log("val_loss", loss, prog_bar=True)
elif stage == "test":
self.test_metrics(preds,labels)
self.log("test_acc", self.test_metrics["test_acc"], prog_bar=True)
self.log("test_f1", self.test_metrics["test_f1"], prog_bar=True)
self.log("test_loss", loss, prog_bar=True)
def validation_step(self, batch, batch_idx):
return self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
return self.evaluate(batch, "test")
def configure_optimizers(self):
return torch.optim.Adam(params=self.parameters(), lr=self.hparams.learning_rate)
model = PetModel("resnet34", 1e-3, num_classes)
We'll log metrics to TensorBoard using the TensorBoardLogger
, and save the best model, measured using accuracy, with the ModelCheckpoint
.
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
name = "oxfordpet" + "_" + model.hparams.arch
logger = TensorBoardLogger(save_dir="lightning_logs",
name=name,
log_graph=True,
default_hp_metric=False)
callbacks = [ModelCheckpoint(monitor="val_acc",save_top_k=1, mode="max") ]
Initialize the trainer by requiring training to be done on the GPU, max epochs of 12, setting the logger, and callbacks.
trainer = Trainer(accelerator='gpu',
devices=1,
max_epochs=12,
logger=logger,
callbacks=callbacks,
fast_dev_run=False)
Visualize training progress in TesnorBoard.
%load_ext tensorboard
%tensorboard --logdir=./lightning_logs --bind_all
Finally, fit the model on the training dataset while saving the best model based on performance on the validation dataset.
trainer.fit(model,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader)
Let's load the best model and measure its performance.
best_model_path = trainer.checkpoint_callback.best_model_path
print(best_model_path)
best_model = PetModel.load_from_checkpoint(checkpoint_path=best_model_path)
trainer.test(best_model,dataloaders=test_dataloader)
Now visualize the confusion matrix.
from torchmetrics import ConfusionMatrix
confmat = ConfusionMatrix(num_classes=num_classes, normalize='true')
prev_device = best_model.device.type
device = torch.device("cuda" if torch.cuda.is_available else "cpu")
confmat.to(device)
best_model.to(device)
with torch.no_grad():
best_model.eval()
for batch in test_dataloader:
images, labels = batch
images, labels = images.to(device), labels.to(device)
logits = best_model(images)
preds = torch.argmax(logits, dim=1)
confmat.update(preds, labels)
best_model.to(prev_device)
1
import seaborn as sns
cmat = confmat.compute().cpu().numpy()
fig, ax = plt.subplots(figsize=(20,20))
df_cm = pd.DataFrame(cmat, index = range(len(classes)), columns=range(len(classes)))
ax = sns.heatmap(df_cm, annot=True, fmt='.2g', cmap='Spectral')
ax.set_yticklabels([classes[i] for i in range(len(classes))], rotation=0)
ax.set_xticklabels([classes[i] for i in range(len(classes))], rotation=90)
ax.set_xlabel("Predicted label")
ax.set_ylabel("Actual label")
plt.show(block=False)