In my last post I showed how to use torchmetrics to implement segmentation metrics for the Oxford-IIIT pet segmentation dataset. We saw that in addition to the average keyword introduced in the pet breed classification post, the mdmc_average keyword is necessary to compute metrics for image data.

In this post we'll dive deeper into these metrics, explaining the two choices for the mdmc_average parameter, global and samplewise, and giving recommendations for dealing with imbalanced datasets.

The examples below will look primarily at precision and the $F1$ score, but note that these metrics can be swapped for recall, the Dice score, and so on.

!pip install pytorch-lightning
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
!pip install seaborn
import torch
import functools
import segmentation_models_pytorch as smp
from torchmetrics.functional.classification import precision, f1_score
from torchmetrics.classification import StatScores
from sklearn import metrics

# Set the seed for reproducibility.
torch.manual_seed(7)

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

To better understand the metrics, we'll work with a $4$-class problem with $n = 100$ samples. Classes $0$ and $3$ will have a probability of occurrence of $\frac{1}{15}$, class $1$ will have a probability of $\frac{2}{3}$, and class $2$ will have a probability of $\frac{1}{5}$. We can generate data having this distribution using torch.multinomial below.

# Sampling weights for the 4 classes (probabilities 1/15, 2/3, 1/5, 1/15).
weights = torch.tensor([1, 10, 3, 1], dtype=torch.float)
num_classes = len(weights)
# 100 single-channel images of size 256 x 256.
shape = (100, 1, 256, 256)
size = functools.reduce(lambda x, y: x * y, shape)
# Sample fake predictions pixel by pixel, then set the last 30 images to the background class.
output = torch.multinomial(weights, size, replacement=True).reshape(shape)
output[70:, :, :, :] = torch.zeros(30, *shape[1:])
# Do the same for the targets.
target = torch.multinomial(weights, size, replacement=True).reshape(shape)
target[70:, :, :, :] = torch.zeros(30, *shape[1:])
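As a quick sanity check, the empirical class frequencies over the first $70$ images (the ones we don't overwrite with zeros) should be close to $\frac{1}{15}$, $\frac{2}{3}$, $\frac{1}{5}$, and $\frac{1}{15}$:

# Empirical class frequencies of the non-zeroed images.
torch.bincount(output[:70].flatten(), minlength=num_classes) / output[:70].numel()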

For example, a subset of the output looks like:

output[0,:,:10,:10]
tensor([[[1, 1, 2, 1, 2, 2, 1, 1, 2, 2],
         [1, 1, 1, 1, 1, 3, 1, 1, 1, 1],
         [1, 1, 3, 2, 1, 2, 1, 1, 1, 1],
         [0, 1, 1, 2, 3, 1, 1, 1, 1, 2],
         [1, 0, 1, 1, 1, 1, 1, 1, 1, 3],
         [1, 1, 1, 2, 0, 1, 1, 0, 1, 1],
         [1, 1, 1, 0, 1, 1, 2, 1, 2, 1],
         [2, 1, 1, 1, 2, 1, 2, 1, 3, 2],
         [3, 1, 1, 3, 1, 2, 1, 1, 1, 1],
         [2, 3, 0, 1, 1, 1, 1, 2, 2, 1]]])

First, we can collapse the sample and image dimensions, $N$, $H$, and $W$, and treat every pixel of every image as a sample in an ordinary multiclass classification problem. This is precisely what happens when we choose mdmc_average="global".

precision(output, target, num_classes=num_classes, average="macro", mdmc_average="global").item()
0.4517214596271515

For comparison's sake, in scikit-learn we have:

metrics.precision_score(target.reshape(-1).numpy(), output.reshape(-1).numpy(), average="macro")
0.4517214441963613

The different options for the average keyword can then be chosen just as in the classification setting, including micro, macro, and weighted.
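For instance, keeping mdmc_average="global" and only swapping the average keyword (the same call as above, so no new API is involved):

# Micro average: pool the confusion counts over all classes before dividing.
precision(output, target, num_classes=num_classes, average="micro", mdmc_average="global").item()

# Weighted average: weight each class's score by its support in the target.
precision(output, target, num_classes=num_classes, average="weighted", mdmc_average="global").item()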

In contrast, the image dimensions can be treated separately, which is called the macro-imagewise reduction (sketched in code after the list):

  1. For each image and class the confusion table is computed over all pixels in an image.
  2. Then the metric is computed for each image and class, as if it were a binary classifier.
  3. The metrics are finally averaged over the images and classes.
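To make these three steps concrete, here is a minimal hand-rolled sketch. The function name f1_macro_imagewise and its zero_division argument are purely illustrative (not library API); the handling of the $\frac{0}{0}$ case it exposes is discussed below.

def f1_macro_imagewise(preds, target, num_classes, zero_division=0.0):
    scores = []
    for image_preds, image_target in zip(preds, target):
        for c in range(num_classes):
            pred_c = image_preds == c
            target_c = image_target == c
            # Step 1: per-image, per-class confusion counts.
            tp = (pred_c & target_c).sum()
            fp = (pred_c & ~target_c).sum()
            fn = (~pred_c & target_c).sum()
            denom = 2 * tp + fp + fn
            if denom == 0:
                # Class absent from both prediction and target: F1 is 0/0.
                scores.append(torch.tensor(zero_division))
            else:
                # Step 2: binary F1 for this image and class.
                scores.append(2 * tp / denom)
    # Step 3: average over images and classes.
    return torch.stack(scores).mean().item()

# With zero_division=0.0 this should agree with the torchmetrics result below,
# and with zero_division=1.0 with the segmentation models default.
f1_macro_imagewise(output, target, num_classes)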

This reduction is the most natural way to calculate metrics like the Jaccard index (intersection over union), for example. Unfortunately the Jaccard index can't be calculated this way in torchmetrics. However, the $F1$/Dice score can be, and it's equivalent to the Jaccard index¹:

f1_score(output, target, num_classes=num_classes, average="macro", mdmc_average="samplewise").item()
0.2497853934764862

However, if we calculate the $F1$ score using the segmentation models library, we get:

tp, fp, fn, tn = smp.metrics.get_stats(output.long(), target.long(), mode='multiclass', num_classes=num_classes)
smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise").item()
0.47478538751602173

This is because our dataset has many images with no targets (recall that we zeroed out several images). For those images the non-background classes have $tp = fp = fn = 0$, so the $F1$ score reduces to $\frac{0}{0}$. smp replaces occurrences of $\frac{0}{0}$ by $1$, while torchmetrics replaces them by $0$. If we pass zero_division=0 to the segmentation models library, we get the same value as torchmetrics:

tp, fp, fn, tn = smp.metrics.get_stats(output.long(), target.long(), mode='multiclass', num_classes=num_classes)
smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise", zero_division=0).item()
0.2497853934764862
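We can check how many images actually have an all-background target and therefore contribute these $\frac{0}{0}$ terms:

# Count images whose target contains only the background class 0; this should
# match the 30 images we zeroed out when generating the data.
(target.flatten(start_dim=1).amax(dim=1) == 0).sum().item()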

This is why we recommend avoiding mdmc_average="samplewise", and instead calculating the metrics as for a regular multiclass classifier with mdmc_average="global".

In conclusion, when dealing with balanced datasets, accuracy with the micro average plus mdmc_average="global" is sufficient, while for imbalanced datasets the $F1$ score with the weighted average plus mdmc_average="global" better reflects performance.
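As a sketch of those recommendations (assuming accuracy can be imported alongside precision and f1_score and accepts the same keywords):

from torchmetrics.functional.classification import accuracy

# Balanced datasets: overall pixel accuracy.
accuracy(output, target, num_classes=num_classes, average="micro", mdmc_average="global").item()

# Imbalanced datasets: F1 score weighted by class support.
f1_score(output, target, num_classes=num_classes, average="weighted", mdmc_average="global").item()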