Multiclass Segmentation Metrics
Multiclass segmentation metrics with torchmetrics, highlighting the difference between micro, macro, and macro-imagewise metrics.
In my last post I showed how to use torchmetrics to implement segmentation metrics for the Oxford-IIIT pet segmentation dataset. We saw that in addition to the average keyword introduced in the pet breed classification post, the mdmc_average keyword is necessary to compute metrics for image data.
In this post we'll dive deeper into these metrics, explaining the two choices for the mdmc_average parameter, global and samplewise, and giving recommendations for dealing with imbalanced datasets.
The examples below look primarily at precision and the $F1$ score, but these metrics can be replaced by recall, the Dice score, etc.
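!pip install pytorch-lightning
!pip install "torchmetrics<0.11"  # mdmc_average, used below, is not available in torchmetrics>=0.11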
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
!pip install seaborn
import functools

import matplotlib.pyplot as plt
import seaborn as sns
import segmentation_models_pytorch as smp
import torch
from matplotlib.colors import ListedColormap
from sklearn import metrics
from torchmetrics.classification import StatScores
from torchmetrics.functional.classification import precision, f1_score

# Set the seed for reproducibility.
torch.manual_seed(7)
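To better understand the metrics, we'll work with a $4$-class problem with $n = 100$ samples, each a single-channel $256 \times 256$ label map. Classes $0$ and $3$ each have a probability of occurrence of $\frac{1}{15}$, class $1$ has a probability of $\frac{2}{3}$, and class $2$ has a probability of $\frac{1}{5}$. We can generate predictions and targets with this distribution using torch.multinomial below. To mimic images that contain only background, the last $30$ images of both the output and the target are then set entirely to class $0$.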
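weights = torch.tensor([1, 10, 3, 1], dtype=torch.float)  # relative class frequencies (sum 15)
num_classes = len(weights)
shape = (100, 1, 256, 256)  # (N, C, H, W): 100 single-channel label maps
size = functools.reduce(lambda x, y: x * y, shape)  # total number of pixels
# Predictions and targets are drawn independently from the same class distribution.
output = torch.multinomial(weights, size, replacement=True).reshape(shape)
output[70:, :, :, :] = torch.zeros(30, *shape[1:])  # last 30 images: background only
target = torch.multinomial(weights, size, replacement=True).reshape(shape)
target[70:, :, :, :] = torch.zeros(30, *shape[1:])  # same 30 images: background only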
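As a quick sanity check, the empirical class frequencies can be compared with the normalized weights. The sketch below is standalone: it draws 200,000 labels rather than the full tensors above (an arbitrary size chosen for speed), splits them into 100 hypothetical "images" of 2,000 pixels, and shows how blanking out the last 30% of the images shifts the overall class balance towards class $0$.

```python
import torch

torch.manual_seed(7)

weights = torch.tensor([1, 10, 3, 1], dtype=torch.float)
num_classes = len(weights)
expected = weights / weights.sum()  # [1/15, 2/3, 1/5, 1/15]

# Draw labels with the same distribution as the synthetic masks.
labels = torch.multinomial(weights, 200_000, replacement=True)
empirical = torch.bincount(labels, minlength=num_classes).float() / labels.numel()
print("expected: ", expected)
print("empirical:", empirical)

# Split into 100 "images" of 2,000 pixels and blank out the last 30, as in the post.
images = labels.reshape(100, -1)
images[70:] = 0
overall = torch.bincount(images.flatten(), minlength=num_classes).float() / images.numel()
# Class 0 now covers roughly 0.7 * 1/15 + 0.3 (about 35%) of all pixels.
print("after zeroing:", overall)
```

So although class $0$ is rare within a typical image, the zeroed images make it the second most frequent class overall.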
For example, the top-left $10 \times 10$ corner of the first predicted mask looks like:
output[0,:,:10,:10]
First, we can flatten the batch and image dimensions ($N$, $H$ and $W$) so that every pixel becomes a separate sample, and then calculate the metrics as for multiclass classification. This is precisely what happens when we set mdmc_average="global".
precision(output, target,num_classes=num_classes,average="macro",mdmc_average="global").item()
For comparison's sake, scikit-learn on the flattened label arrays gives:
metrics.precision_score(target.reshape((-1)).numpy(),output.reshape((-1)).numpy(), average="macro")
With the pixels flattened, any of the usual options for average can be chosen, including micro, macro, and weighted.
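To make the three options concrete, here is a plain-PyTorch sketch of micro, macro and weighted precision over flattened pixels. precision_averages is a hypothetical helper written for this illustration, not a torchmetrics function; it is intended to follow scikit-learn's definitions and scores a never-predicted class as $0$. For single-label multiclass data the micro-averaged precision equals the accuracy, because every wrong pixel is exactly one false positive (for the class that was predicted).

```python
import torch

def precision_averages(preds: torch.Tensor, target: torch.Tensor, num_classes: int):
    """Hypothetical helper: micro, macro and weighted precision over all pixels."""
    preds, target = preds.flatten(), target.flatten()  # every pixel is one sample
    classes = torch.arange(num_classes).unsqueeze(1)   # (K, 1)
    pred_k = preds.unsqueeze(0) == classes             # (K, P) one-vs-rest predictions
    true_k = target.unsqueeze(0) == classes            # (K, P) one-vs-rest targets
    tp = (pred_k & true_k).sum(dim=1).float()
    fp = (pred_k & ~true_k).sum(dim=1).float()
    support = true_k.sum(dim=1).float()                # true pixels per class

    per_class = tp / (tp + fp).clamp(min=1)            # 0 if a class is never predicted
    micro = tp.sum() / (tp + fp).sum()                 # pooled counts, equals accuracy here
    macro = per_class.mean()                           # every class counts equally
    weighted = (per_class * support).sum() / support.sum()  # classes weighted by frequency
    return micro.item(), macro.item(), weighted.item()

preds = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
micro, macro, weighted = precision_averages(preds, target, num_classes=3)
print(micro, macro, weighted)  # micro 0.75, macro ~0.833, weighted 0.875
```

Here class $1$ has precision $0.5$ while classes $0$ and $2$ have precision $1$; the weighted average is pulled up because class $2$ has twice the support of the other classes.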
In contrast, each image can be treated separately, which the segmentation models library calls the macro-imagewise reduction (in torchmetrics, mdmc_average="samplewise"):
- For each image and class, the confusion counts (true/false positives and negatives) are computed over all pixels in that image.
- The metric is then computed for each image and class, as if it were a binary classifier.
- Finally, the metrics are averaged over the classes and images.
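The three steps above can be sketched in plain PyTorch for the $F1$ score. macro_imagewise_f1 is a hypothetical helper, not part of torchmetrics or smp; its zero_division argument is modelled on the smp parameter of the same name and sets the score used when a class is absent from both the prediction and the target of an image.

```python
import torch

def macro_imagewise_f1(preds: torch.Tensor, target: torch.Tensor, num_classes: int,
                       zero_division: float = 0.0) -> float:
    """Hypothetical helper: per-image, per-class F1, averaged over classes, then images."""
    n = preds.shape[0]
    preds = preds.reshape(n, 1, -1)    # (N, 1, P): pixels of each image
    target = target.reshape(n, 1, -1)
    classes = torch.arange(num_classes).view(1, -1, 1)  # (1, K, 1)
    pred_k = preds == classes          # (N, K, P) one-vs-rest masks per image
    true_k = target == classes

    # Step 1: confusion counts per image and class, summed over that image's pixels.
    tp = (pred_k & true_k).sum(dim=-1).float()   # (N, K)
    fp = (pred_k & ~true_k).sum(dim=-1).float()
    fn = (~pred_k & true_k).sum(dim=-1).float()

    # Step 2: binary F1 = 2TP / (2TP + FP + FN) for each image and class.
    denom = 2 * tp + fp + fn
    f1 = torch.where(denom > 0, 2 * tp / denom.clamp(min=1),
                     torch.full_like(denom, zero_division))  # 0/0 -> zero_division

    # Step 3: average over classes within each image, then over images.
    return f1.mean(dim=1).mean().item()

# One 2x2 image in which class 2 never appears, so its F1 is 0/0.
mask = torch.tensor([[[0, 1], [1, 1]]])
print(macro_imagewise_f1(mask, mask, num_classes=3))                     # ~0.667
print(macro_imagewise_f1(mask, mask, num_classes=3, zero_division=1.0))  # 1.0
```

Even for a perfect prediction, the score depends on how the absent class is counted, which is exactly the effect discussed below for images without targets.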
This is the most natural way to calculate metrics like the Jaccard index (intersection over union). Unfortunately, the Jaccard index can't be calculated this way using torchmetrics. However, the $F1$/Dice score can, and for a given image and class it is equivalent to the Jaccard index in the sense that each is an increasing function of the other¹:
f1_score(output, target,num_classes=num_classes,average="macro",mdmc_average="samplewise").item()
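For a single image and class, both scores are functions of the same confusion counts, which makes the relation explicit (writing $J$ for the Jaccard index):

$$
J = \frac{TP}{TP + FP + FN}, \qquad
F1 = \frac{2\,TP}{2\,TP + FP + FN} = \frac{2J}{1 + J}.
$$

Since $\frac{2J}{1+J}$ is increasing in $J$, a higher $F1$ for a given image and class always means a higher Jaccard index, but averaged scores cannot be converted into one another with this formula.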
However, if we calculate the $F1$ score using the segmentation models library (smp), we get a different value:
tp, fp, fn, tn = smp.metrics.get_stats(output.long(), target.long(), mode='multiclass', num_classes=num_classes)
smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise").item()
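This is because our dataset has many images with no targets: the last $30$ images contain only the background class in both output and target (recall that we zeroed them out). For these images, the non-background classes have no true positives, false positives, or false negatives, so their $F1$ score reduces to $\frac{0}{0}$. smp replaces occurrences of $\frac{0}{0}$ by $1$ (the default value of its zero_division argument), while torchmetrics replaces $\frac{0}{0}$ by $0$.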
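The size of the difference can be estimated by hand. The sketch below scores a single all-background image with $4$ classes (per_class_f1 is a hypothetical helper): class $0$ is perfectly predicted, while classes $1$ to $3$ are $\frac{0}{0}$. Assuming none of the first $70$ images has a $\frac{0}{0}$ class (very likely with $65{,}536$ pixels per image), the $30$ blank images contribute $30 \times 3 = 90$ of the $100 \times 4 = 400$ (image, class) scores, so switching the convention from $1$ to $0$ should lower the macro-imagewise average by about $90/400 = 0.225$.

```python
import torch

def per_class_f1(pred: torch.Tensor, true: torch.Tensor, num_classes: int,
                 zero_division: float) -> list:
    """Hypothetical helper: binary F1 of each class for a single image."""
    scores = []
    for k in range(num_classes):
        p, t = pred == k, true == k
        tp = (p & t).sum().item()
        fp = (p & ~t).sum().item()
        fn = (~p & t).sum().item()
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else zero_division)  # 0/0 -> zero_division
    return scores

# An image that is background (class 0) in both prediction and target.
blank = torch.zeros(4, 4, dtype=torch.long)
print(per_class_f1(blank, blank, num_classes=4, zero_division=1.0))  # [1.0, 1.0, 1.0, 1.0]
print(per_class_f1(blank, blank, num_classes=4, zero_division=0.0))  # [1.0, 0.0, 0.0, 0.0]

# 30 blank images x 3 undefined classes, out of 100 images x 4 classes.
gap = (30 * 3) / (100 * 4) * (1.0 - 0.0)
print(gap)  # 0.225
```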
If we pass zero_division=0 to the segmentation models library, we get the same value as torchmetrics:
tp, fp, fn, tn = smp.metrics.get_stats(output.long(), target.long(), mode='multiclass', num_classes=num_classes)
smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise", zero_division=0).item()
This is why we recommend avoiding mdmc_average="samplewise": the result depends on an arbitrary convention for $\frac{0}{0}$. Instead, calculate the metrics as for a regular multiclass classifier, i.e. with mdmc_average="global".
In conclusion, when dealing with balanced datasets, accuracy using the micro average plus mdmc_average="global" is sufficient, while for imbalanced datasets the $F1$ score with the weighted average plus mdmc_average="global" gives a more faithful picture.
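To illustrate why accuracy can be misleading on imbalanced data, the sketch below scores a hypothetical predictor that outputs the majority class $1$ for every pixel, against a toy target with exactly the class proportions used in this post ($1$, $10$, $3$ and $1$ pixels out of $15$). f1_averages is a hypothetical helper written for this example.

```python
import torch

def f1_averages(preds: torch.Tensor, target: torch.Tensor, num_classes: int):
    """Hypothetical helper: accuracy plus macro and weighted F1 over all pixels."""
    preds, target = preds.flatten(), target.flatten()
    classes = torch.arange(num_classes).unsqueeze(1)   # (K, 1)
    pred_k = preds.unsqueeze(0) == classes             # (K, P)
    true_k = target.unsqueeze(0) == classes
    tp = (pred_k & true_k).sum(dim=1).float()
    fp = (pred_k & ~true_k).sum(dim=1).float()
    fn = (~pred_k & true_k).sum(dim=1).float()
    support = true_k.sum(dim=1).float()

    f1 = 2 * tp / (2 * tp + fp + fn).clamp(min=1)      # 0 for a class with no TP/FP/FN
    accuracy = (preds == target).float().mean()
    macro = f1.mean()
    weighted = (f1 * support).sum() / support.sum()
    return accuracy.item(), macro.item(), weighted.item()

# Toy target with the post's class proportions: 1/15, 10/15, 3/15, 1/15.
target = torch.tensor([0] + [1] * 10 + [2] * 3 + [3])
preds = torch.ones_like(target)   # always predict the majority class 1
accuracy, macro, weighted = f1_averages(preds, target, num_classes=4)
print(accuracy, macro, weighted)  # ~0.667, ~0.2, ~0.533
```

Accuracy rewards always predicting the majority class (about $0.67$), whereas the weighted $F1$ score drops to about $0.53$ because every minority class receives an $F1$ of $0$.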