深度学习-交叉熵

交叉熵(Cross-Entropy) 是信息论中的一个核心概念,在深度学习中,它是最常用、最重要的损失函数之一,尤其擅长处理分类问题

简单理解,交叉熵可以用来衡量两个概率分布之间的差异。在模型训练中,它衡量的就是:模型预测的概率分布,与真实的概率分布(通常是One-hot编码的标签)之间的差距。

  • 预测越准,交叉熵损失越小。
  • 预测越离谱,交叉熵损失越大。

1. 从公式理解

对于单个样本,交叉熵损失的公式如下:

[
\text{CrossEntropy} = -\sum_{i} y_i \log(p_i)
]

  • ( y_i ):表示第 ( i ) 个类别的真实标签。在分类任务中,真实类别为1,其他为0。
  • ( p_i ):表示模型预测样本属于第 ( i ) 个类别的概率,取值范围在0到1之间。
  • ( \log ):是自然对数。

由于 ( y_i ) 只在真实类别(比如第 ( c ) 类)上为1,其他全为0,所以这个公式可以简化为:

[
\text{Loss} = -\log(p_c)
]

这个简化公式非常直观地说明了交叉熵的工作原理:损失的大小,完全由模型给正确类别预测的概率 ( p_c ) 决定。

  • 当模型预测正确类别的概率 ( p_c = 1 ) 时,( -\log(1) = 0 ),损失为0。
  • 当 ( p_c = 0.5 ) 时,( -\log(0.5) \approx 0.693 )。
  • 当 ( p_c = 0.1 ) 时,( -\log(0.1) \approx 2.302 )。
  • 当 ( p_c ) 趋近于 0 时,( -\log(p_c) ) 会趋近于正无穷。

2. 一个具体例子

假设你有一个图像分类任务,图片是一只。分类类别有:猫、狗、鸟。

  • 真实标签 (One-hot 编码)[1, 0, 0] (猫)
  • 模型A的预测 (很准)[0.9, 0.05, 0.05] 。损失 = ( -\log(0.9) \approx 0.105 )。
  • 模型B的预测 (不太准)[0.4, 0.5, 0.1] 。损失 = ( -\log(0.4) \approx 0.916 )。
  • 模型C的预测 (完全错误)[0.05, 0.9, 0.05] 。损失 = ( -\log(0.05) \approx 3.0 )。

可以看到,模型A(预测正确概率高)的损失很小,而模型C(预测错误)的损失非常大。通过反向传播,交叉熵损失函数会驱使模型不断提高对正确类别的预测概率。

3. 为什么在分类任务中如此有效?

交叉熵之所以被广泛使用,主要有三个优势:

  1. 梯度更大,学习更快
    与均方误差(MSE)等损失函数相比,当模型的预测结果与真实标签相差甚远时,交叉熵能提供一个很大的梯度,模型会进行大幅度的修正,从而快速改进。而MSE在初期错误率很高时梯度可能会很小,导致学习缓慢。

  2. 结合Softmax,天然适配多分类
    在神经网络中,最后一层输出的原始数值(logits)通常无法直接视为概率。交叉熵损失函数常常与 Softmax 激活函数配合使用。Softmax能把logits转换成和为1的概率分布,这和交叉熵对输入的预期(概率分布)是天作之合。

4. CrossEntropyLoss vs. BCELoss

在使用PyTorch等框架时,你会遇到几种名称相似但功能不同的交叉熵损失,需要注意区分:

损失函数适用任务最后一层激活函数标签形式说明
nn.CrossEntropyLoss多分类 (互斥类别)无需 (或Linear)类别索引 (如 1最常用,内部融合了LogSoftmax和NLLLoss,不需要在输出层再加Softmax。
nn.BCELoss二分类 或多标签Sigmoid0/1 概率值需要手动在输出层加Sigmoid。
nn.BCEWithLogitsLoss二分类 或多标签无需 (或Linear)0/1 数值推荐,比BCELoss更数值稳定,内部融合了Sigmoid和BCELoss。

代码示例(使用 nn.CrossEntropyLoss 进行多分类):

import torch
import torch.nn as nn

# 假设有3个类别,2个样本
logits = torch.tensor([[2.0, 1.0, 0.1]# 模型对样本1的输出
                       [0.5, 2.5, 0.3]])      # 模型对样本2的输出
# 真实标签:样本1属于第0类,样本2属于第1类
labels = torch.tensor([01])

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')   # 输出: 例如 0.774

总结

一句话总结
交叉熵是分类任务中衡量“预测概率分布”与“真实概率分布”差异的标准工具。模型通过最小化交叉熵来让自己对正确类别的预测概率趋近于1,从而学会正确分类。

简单来说,交叉熵是一个聪明的“教练”:当你学得差时,它给你严厉的惩罚(大损失,大梯度);当你学得好时,它给你温柔的鼓励(小损失,小梯度),引导模型快速收敛。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值