1.Dice Loss
Dice loss 来自文献[1],是从Dice系数推广得到的损失函数。
Dice 系数是一种集合相似度度量函数,是从区域角度衡量两个集合的相似度。(CE Loss是 从概率分布角度)
Dice 系数值域为 [0, 1] ,两个集合完全重叠时为1, 完全不重叠时为0,计算公式如下
Dice loss = ,值域[0, 1],loss值越小,重合度越高
分母的计算:|A| 和 |B|分别表示A、B的元素个数
分子的计算:A和B的交集,用点乘。

# -*- coding: utf-8 -*-
"""
# @file name : camvid_config.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2020-03-12
# @brief : dice loss
"""
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
soft dice loss, 直接使用预测概率而不是使用阈值或将它们转换为二进制mask
"""
def __init__(self, epsilon=1e-5):
super(DiceLoss, self).__init__()
self.epsilon = epsilon
def forward(self, predict, target):
assert predict.size() == target.size(), "the size of predict and target must be equal."
num = predict.size(0)
# pred不需要转bool变量,如https://github.com/yassouali/pytorch-segmentation/blob/master/utils/losses.py#L44
# soft dice loss, 直接使用预测概率而不是使用阈值或将它们转换为二进制mask
pred = torch.sigmoid(predict).view(num, -1)
targ = target.view(num, -1)
intersection = (pred * targ).sum() # 利用预测值与标签相乘当作交集
union = (pred + targ).sum()
score = 1 - 2 * (intersection + self.epsilon) / (union + self.epsilon)
return score
if __name__ == "__main__":
fake_out = torch.tensor([7, 7, -5, -5], dtype=torch.float32)
fake_label = torch.tensor([1, 1, 0, 0], dtype=torch.float32)
loss_f = DiceLoss()
loss = loss_f(fake_out, fake_label)
print(loss)
2.Focal Loss
Focal Loss:针对Two-Stage目标检测RPN网络中正负样本严重不均衡及困难样本提出的Loss,是在CE loss 基础上改进而来。CE LOSS如下:

解决不均衡:增加类别权重αt

解决困难样本:增加难度权重γ

最终Focal loss 公式如下:


# -*- coding: utf-8 -*-
"""
# @file name : focal_loss.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2020-03-12
# @brief : 标准的 focal loss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=None, ignore_index=255, size_average=True):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.size_average = size_average
self.CE_loss = nn.CrossEntropyLoss(ignore_index=ignore_index, weight=alpha)
def forward(self, output, target):
logpt = self.CE_loss(output, target)
pt = torch.exp(-logpt) # 因CE中取了log,所以要exp回来,就得到概率。因为输入并不是概率,CEloss中自带softmax转为概率形式
loss = ((1-pt)**self.gamma) * logpt
if self.size_average:
return loss.mean()
return loss.sum()
if __name__ == "__main__":
target = torch.tensor([1], dtype=torch.long)
gamma_lst = [0, 0.5, 1, 2, 5]
loss_dict = {}
for gamma in gamma_lst:
focal_loss_func = FocalLoss(gamma=gamma)
loss_dict.setdefault(gamma, [])
for i in np.linspace(0.5, 10.0, num=30):
outputs = torch.tensor([[5, i]], dtype=torch.float) # 制造不同概率的输出
prob = F.softmax(outputs, dim=1) # 由于pytorch的CE自带softmax,因此想要知道具体预测概率,需要自己softmax
loss = focal_loss_func(outputs, target)
loss_dict[gamma].append((prob[0, 1].item(), loss.item()))
for gamma, value in loss_dict.items():
x_prob = [prob for prob, loss in value]
y_loss = [loss for prob, loss in value]
plt.plot(x_prob, y_loss, label="γ="+str(gamma))
plt.title("Focal Loss")
plt.xlabel("probability of ground truth class")
plt.ylabel("loss")
plt.legend()
plt.show()
3.BCE Loss
(Binary Cross-Entropy Loss) 是用于二分类问题的损失函数。它用于评估预测值和实际标签之间的差异。计算过程如下:

import torch
import torch.nn as nn
# 创建一个示例输入(预测值)和标签
predictions = torch.tensor([0.1, 0.9, 0.8, 0.3], dtype=torch.float32)
labels = torch.tensor([0, 1, 1, 0], dtype=torch.float32)
# 初始化 BCELoss
criterion = nn.BCELoss()
# 计算损失
loss = criterion(predictions, labels)
print(f"Binary Cross-Entropy Loss: {loss.item()}")
4.Bce+Dice (训练效果好)
采用BCE + Dice ,权重比例1:1
import torch
import torch.nn as nn
from losses.dice_loss import DiceLoss
class BCEDiceLoss(nn.Module):
def __init__(self, **kwargs):
super(BCEDiceLoss, self).__init__()
self.bce_func = nn.BCEWithLogitsLoss(**kwargs) # *args和**kwargs,python可变长参数
self.dice_func = DiceLoss()
def forward(self, predict, target):
loss_bce = self.bce_func(predict, target)
loss_dice = self.dice_func(predict, target)
return loss_dice + loss_bce
if __name__ == "__main__":
fake_out = torch.tensor([1, 1, -1, -1], dtype=torch.float32)
fake_label = torch.tensor([1, 1, 0, 0], dtype=torch.float32)
loss_f = BCEDiceLoss()
loss = loss_f(fake_out, fake_label)
print(loss)
[1]《V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation》
1万+

被折叠的 条评论
为什么被折叠?



