Categorical cross entropy (CCE) and the Dice index (DICE) are popular loss functions for training neural networks for semantic segmentation. In the medical field, the images being analyzed consist mainly of background pixels, with only a few pixels belonging to the objects of interest. Such cases of high class imbalance cause networks trained with CCE to be biased towards the background. To account for that, foreground and background pixels can be weighted differently. In contrast to CCE, DICE does not require weighting to successfully train models on imbalanced datasets [1].
Notation: In the following, $y_{n,c}$ and $p_{n,c}$ denote the $c$th channel at the $n$th pixel location of the reference labels and the neural network softmax output, respectively. $y_n$ is a one-hot vector of length $C$ with $1$ at the reference class location and $0$ elsewhere. I use $C$ to denote the total channel count, $N$ to denote the total pixel count in a mini-batch, and $\epsilon$ as a small constant added to avoid numerical problems.
The code examples below use TensorFlow and assume that models are fed with 4-D tensors of shape (batch_dim, y_dim, x_dim, channel_dim).
Softmax output: The loss functions are computed on the softmax output, which interprets the model output as unnormalized log probabilities and squashes them into the range $[0, 1]$ such that $\sum_{c=1}^{C} p_{n,c} = 1$ for a given pixel location $n$.
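Concretely, with $z_{n,c}$ denoting the raw model output (logits, a symbol introduced here), the softmax is $p_{n,c} = \exp(z_{n,c}) / \sum_{k=1}^{C} \exp(z_{n,k})$. In TensorFlow this is what `tf.nn.softmax` computes over the last axis. The NumPy sketch below is only an illustration (the function name `softmax` and the random logits are not from the original experiments); it checks the two properties stated above on a (batch, y, x, channel) array.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Softmax over the channel axis of a (batch, y, x, channel) array."""
    # Subtract the per-pixel maximum for numerical stability; softmax is
    # invariant to adding a constant to all logits of a pixel.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

# Random logits for a mini-batch of 2 images, 4x4 pixels, 3 channels
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 4, 3))
probs = softmax(logits)
print(probs.sum(axis=-1).min(), probs.sum(axis=-1).max())  # both approximately 1.0
```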
Categorical Cross Entropy
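Categorical cross entropy sums the negative logs of the output probabilities for the correct class. Formally, averaged over the pixels of a mini-batch, it is defined as:

$$\mathrm{CCE} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{c=1}^{C} y_{n,c}\,\log p_{n,c}$$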
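```python
import tensorflow as tf

def cce_loss(softmax_output, labels):
    # Clip probabilities to avoid log(0)
    eps = 1e-7
    log_p = -tf.math.log(tf.clip_by_value(softmax_output, eps, 1 - eps))
    # Sum over channels: only the reference class (one-hot label) contributes
    loss = tf.reduce_sum(labels * log_p, axis=-1)
    # Average over all pixels in the mini-batch (the 1/N factor)
    return tf.reduce_mean(loss)
```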
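As a worked check of the formula, here is a NumPy mirror of `cce_loss` (the name `cce_loss_np` is hypothetical) applied to a tiny input with one background pixel and one foreground pixel. The expected value is $(-\log 0.9 - \log 0.6)/2$.

```python
import numpy as np

def cce_loss_np(softmax_output, labels, eps=1e-7):
    """NumPy mirror of cce_loss: mean over pixels of -sum_c y * log(p)."""
    log_p = -np.log(np.clip(softmax_output, eps, 1 - eps))
    # Per-pixel sum over channels; only the reference class is non-zero
    loss = np.sum(labels * log_p, axis=-1)
    return np.mean(loss)

# One image of 1x2 pixels (shape 1x1x2x2): a bg pixel and a fg pixel
labels = np.array([[[[1.0, 0.0], [0.0, 1.0]]]])
probs = np.array([[[[0.9, 0.1], [0.4, 0.6]]]])
print(cce_loss_np(probs, labels))  # (-log 0.9 - log 0.6) / 2, approx. 0.308
```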
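Dice Loss
The Dice loss function (DICE) can be defined as:

$$\mathrm{DICE} = 1 - \frac{1}{C}\sum_{c=1}^{C}\frac{2\sum_{n=1}^{N} y_{n,c}\,p_{n,c} + \epsilon}{\sum_{n=1}^{N} y_{n,c} + \sum_{n=1}^{N} p_{n,c} + \epsilon}$$

or, using squares in the denominator (DICE_SQUARE), as proposed by Milletari et al. [1]:

$$\mathrm{DICE\_SQUARE} = 1 - \frac{1}{C}\sum_{c=1}^{C}\frac{2\sum_{n=1}^{N} y_{n,c}\,p_{n,c} + \epsilon}{\sum_{n=1}^{N} y_{n,c}^{2} + \sum_{n=1}^{N} p_{n,c}^{2} + \epsilon}$$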
$\epsilon$ is used to avoid division by 0 (denominator) and to allow learning from patches that contain no pixels of the $c$th class in the reference (numerator). The multiplication by $\frac{1}{C}$ has the nice property that the loss lies within $[0, 1]$ regardless of the channel count. Optionally, the Dice loss can be computed only for the foreground channels (DICEFG, DICEFG_SQUARE), since Dice computed on the foreground channels alone already punishes false positives.
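A short argument for the $[0, 1]$ bound, assuming (as holds for one-hot labels and softmax outputs) that every $y_{n,c}$ and $p_{n,c}$ lies in $[0, 1]$:

$$y_{n,c}\,p_{n,c} \le \min(y_{n,c}, p_{n,c}) \le \tfrac{1}{2}\,(y_{n,c} + p_{n,c}), \qquad 2\,y_{n,c}\,p_{n,c} \le y_{n,c}^{2} + p_{n,c}^{2}$$

The second inequality follows from $(y_{n,c} - p_{n,c})^2 \ge 0$. Summing over $n$ shows that, before adding $\epsilon$, each numerator is at most its denominator for both DICE and DICE_SQUARE, and for $0 \le a \le b$ we have $0 < (a + \epsilon)/(b + \epsilon) \le 1$. Each per-channel ratio therefore lies in $(0, 1]$, their average with the factor $\frac{1}{C}$ does too, and the loss lies in $[0, 1)$ for any channel count.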
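```python
import tensorflow as tf

def dice_loss(softmax_output, labels, ignore_background=False, square=False):
    if ignore_background:
        # Foreground-only variants (DICEFG, DICEFG_SQUARE): drop channel 0
        labels = labels[..., 1:]
        softmax_output = softmax_output[..., 1:]
    # Sum over batch and spatial axes -> one Dice score per channel
    axis = (0, 1, 2)
    eps = 1e-7
    nom = 2 * tf.reduce_sum(labels * softmax_output, axis=axis) + eps
    if square:
        # DICE_SQUARE: squared terms in the denominator only
        labels = tf.square(labels)
        softmax_output = tf.square(softmax_output)
    denom = tf.reduce_sum(labels, axis=axis) + tf.reduce_sum(softmax_output, axis=axis) + eps
    # Average over channels (the 1/C factor) and turn the score into a loss
    return 1 - tf.reduce_mean(nom / denom)
```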
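To check the role of $\epsilon$ described above without TensorFlow, here is a NumPy mirror of `dice_loss` (the name `dice_loss_np` is hypothetical) applied to a patch whose reference contains no foreground pixels. A perfect prediction gives a loss of exactly 0, because the empty foreground channel contributes $\epsilon / \epsilon = 1$; predicting the opposite class everywhere gives a loss close to 1.

```python
import numpy as np

def dice_loss_np(softmax_output, labels, ignore_background=False, square=False, eps=1e-7):
    """NumPy mirror of dice_loss for (batch, y, x, channel) arrays."""
    if ignore_background:
        # Drop channel 0 (background), as in DICEFG / DICEFG_SQUARE
        labels = labels[..., 1:]
        softmax_output = softmax_output[..., 1:]
    axis = (0, 1, 2)  # sum over batch and spatial axes -> one value per channel
    nom = 2 * np.sum(labels * softmax_output, axis=axis) + eps
    if square:
        labels = np.square(labels)
        softmax_output = np.square(softmax_output)
    denom = np.sum(labels, axis=axis) + np.sum(softmax_output, axis=axis) + eps
    return 1 - np.mean(nom / denom)

# A 2x2 patch with no foreground pixels in the reference (channels: bg, fg)
labels = np.zeros((1, 2, 2, 2))
labels[..., 0] = 1.0
perfect = labels.copy()           # prediction matches the reference exactly
worst = labels[..., ::-1].copy()  # every pixel predicted as foreground
print(dice_loss_np(perfect, labels))  # 0.0
print(dice_loss_np(worst, labels))    # close to 1.0
```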
I prepared a toy segmentation task with one foreground class, using np.random.rand to generate the input images. A pixel is foreground if its value is bigger than some threshold $t$. By tinkering with $t$, we can set the balance between the output classes:
- $t$ = 0.95: 95% bg, 5% fg
- $t$ = 0.5: 50% bg, 50% fg
- $t$ = 0.05: 5% bg, 95% fg
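A minimal sketch of such a data generator, assuming single-channel inputs and one-hot labels with background in channel 0 and foreground in channel 1 (the channel order that `ignore_background` in `dice_loss` implies). The function name `make_toy_batch` and its parameters are hypothetical, not taken from the experiment code. Because np.random.rand draws uniformly from $[0, 1)$, the expected foreground fraction is $1 - t$.

```python
import numpy as np

def make_toy_batch(batch_size, height, width, t, seed=None):
    """Toy segmentation data: a pixel is foreground if its value exceeds t.

    Returns inputs of shape (batch, y, x, 1) and one-hot labels of shape
    (batch, y, x, 2), with channel 0 = background and channel 1 = foreground.
    """
    if seed is not None:
        np.random.seed(seed)
    images = np.random.rand(batch_size, height, width, 1)  # uniform in [0, 1)
    fg = (images[..., 0] > t).astype(np.float32)
    labels = np.stack([1.0 - fg, fg], axis=-1)
    return images.astype(np.float32), labels

for t in (0.95, 0.5, 0.05):
    _, labels = make_toy_batch(8, 64, 64, t, seed=0)
    print(t, labels[..., 1].mean())  # foreground fraction, roughly 1 - t
```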
To investigate the behavior of the different loss functions, I trained a model with each of them on the various bg/fg ratios and recorded the values and global gradient norms of all loss functions during training. I also trained models with and without batch normalization (BN) before each nonlinearity, using both the ADAM and SGD optimizers.
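The post does not show how the global gradient norm was computed. The usual definition, and the one implemented by `tf.linalg.global_norm`, is the square root of the sum of the squared L2 norms of all gradient tensors. A minimal NumPy sketch with made-up gradient values (`global_norm` and `grads` are illustrative names):

```python
import numpy as np

def global_norm(gradients):
    """Square root of the sum of squared L2 norms of all gradient tensors."""
    return float(np.sqrt(sum(np.sum(np.square(g)) for g in gradients)))

# Hypothetical gradients for a 2x2 kernel and a bias vector
grads = [np.array([[3.0, 0.0], [0.0, 0.0]]), np.array([4.0])]
print(global_norm(grads))  # 5.0
```

Because all tensors are pooled into one norm, a single layer with large gradients dominates the value, which is what makes it a convenient one-number summary of training dynamics.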
You can find the code I used for the experiments here [2]. Feel free to take a look at it; I encourage you to run further experiments using the IPython notebook. Some results are plotted below (the column name denotes the loss function used for training). All plots can be found here [3].
None of the models converged when optimized with SGD, whether BN was used or not. With ADAM, all models without BN converged except the one trained with DICEFG on the 95% fg data. Using BN before each activation improved the models' accuracy and allowed the model trained with DICEFG on the 95% fg data to train successfully. When training with DICEFG on the 50% fg dataset, the CCE loss and the gradient norm skyrocketed at the beginning of training.
It is interesting to note that the gradient norms for the 95% fg case are smaller than those for the 5% fg case. This is a strange observation because, from the point of view of the DICE and CCE loss functions for example, both situations should be indistinguishable, since these losses take bg and fg equally into account.
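The symmetry argument can be checked numerically: swapping the bg and fg channels in both labels and predictions turns a 5% fg batch into a 95% fg batch, and CCE and DICE return the same value for both. The NumPy sketch below uses compact re-implementations (`cce`, `dice`, hypothetical names) mirroring the TensorFlow functions above, with random stand-in predictions. It only checks the loss values; it does not explain the gradient norms observed in the experiments.

```python
import numpy as np

def cce(p, y, eps=1e-7):
    # Mean over pixels of -sum_c y * log(p), as in cce_loss
    return np.mean(np.sum(-y * np.log(np.clip(p, eps, 1 - eps)), axis=-1))

def dice(p, y, eps=1e-7):
    # 1 - mean over channels of the per-channel Dice ratio, as in dice_loss
    axis = (0, 1, 2)
    nom = 2 * np.sum(y * p, axis=axis) + eps
    denom = np.sum(y, axis=axis) + np.sum(p, axis=axis) + eps
    return 1 - np.mean(nom / denom)

rng = np.random.default_rng(0)
fg = (rng.random((4, 16, 16)) > 0.95).astype(float)  # about 5% foreground
y = np.stack([1 - fg, fg], axis=-1)                   # channels: bg, fg
q = rng.random((4, 16, 16))
p = np.stack([1 - q, q], axis=-1)                     # stand-in 2-class softmax output

# Swapping bg and fg in labels and prediction gives the ~95% fg case
y_sw, p_sw = y[..., ::-1], p[..., ::-1]
print(np.isclose(cce(p, y), cce(p_sw, y_sw)))    # True
print(np.isclose(dice(p, y), dice(p_sw, y_sw)))  # True
```

DICEFG, in contrast, only looks at the foreground channel and is not symmetric under this swap: `dice(p[..., 1:], y[..., 1:])` and `dice(p_sw[..., 1:], y_sw[..., 1:])` differ.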
For this toy segmentation task, all models trained with ADAM and batch normalization achieved good accuracy on all bg/fg configurations, regardless of the loss function used. Batch normalization improved the models' performance and was essential to make the model trained with DICEFG converge on the 95% fg dataset.