09 SFT Training Loop

LLM 算子深度解析:09 SFT Training Loop — 两行代码决定你的 SFT 是"微调"还是"背诵"


1. SFT 训练的核心矛盾:教会模型回答,不是教会模型提问

1.1 预训练和微调的本质区别

预训练阶段,模型的任务是"给定前文,预测下一个字"。训练数据是一本一本书、一篇文章一篇文章——每一个 token 都要算 Loss,因为模型需要学会语言本身的规律。

SFT 阶段完全不同。一条训练数据长这样:

[Prompt]  请帮我写一首关于春天的五言绝句。
[Response] 春风拂柳绿,细雨润花红。燕舞莺歌处,人间万象融。

我们的目标是:模型看到 Prompt 后,能输出 Response。我们只关心 Response 的质量,不关心 Prompt 长什么样。

如果你把 Prompt 和 Response 一起送进 CrossEntropyLoss,模型会干嘛?它会努力去"背诵"人类的提问方式——“请帮我”、“写一首”、“关于春天的”——这些 token 的预测也会产生梯度,白白浪费算力去学一些对回答问题毫无帮助的东西。

1.2 解决方案就在一个参数里

PyTorch 的 CrossEntropyLoss 有一个神仙参数叫 ignore_index,默认值是 -100。任何 labels 中值为 -100 的位置,都不会产生梯度,也不会计入 Loss 统计。

所以 SFT 训练的核心 trick 就是一句话:把 labels 中属于 Prompt 的部分全部设成 -100

输入:  [请, 帮我, 写, 一首, 诗, 春风, 拂柳, 绿, ...]
       |--- Prompt ---| |------ Response ------|

labels: [-100, -100, -100, -100, -100, 春风, 拂柳, 绿, ...]
         ↑ 全 mask,不产生梯度              ↑ 保留原样,真正的监督信号

2. Shift 错位对齐:自回归模型的一个"坑"

2.1 为什么需要错位?

自回归语言模型的预测逻辑是:

用 token_0, token_1, ..., token_t  →  预测 token_{t+1}

也就是模型的输出 logits[t] 预测的是第 t+1 个位置的 token。但 labels 序列的第 t 个位置存的是 token_t 本身——这就错位了。

打个比方:老师给你看前 3 个字,让你猜第 4 个字。答案应该是第 4 个字,而不是第 3 个字。

2.2 怎么对齐?

shift_logits = logits[..., :-1, :]    # 丢掉最后一个位置的预测
shift_labels = labels[..., 1:]        # 丢掉第一个位置的标签

对齐之后:

Shift 前:   logits[t] → 预测 labels[t]   ❌ 自己预测自己(复制)
Shift 后:   logits[t] → 预测 labels[t+1] ✅ 前文预测后文(生成)

3. 代码实现:从零手写 SFT 训练核心

3.1 数据构造:一行代码的讲究

def build_sft_data(prompt_ids: list[int], response_ids: list[int],
                   pad_id: int = 0, max_len: int = 16):
    # Step 1: 拼接
    input_ids = prompt_ids + response_ids

    # Step 2: 构造 labels — 核心就这一行
    labels = [-100] * len(prompt_ids) + response_ids

    # Step 3: 截断与填充
    if len(input_ids) > max_len:
        input_ids = input_ids[:max_len]
        labels = labels[:max_len]
    else:
        pad_len = max_len - len(input_ids)
        input_ids = input_ids + [pad_id] * pad_len
        labels = labels + [-100] * pad_len      # padding 也要 mask!

    return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

三个关键点:

  1. [-100] * len(prompt_ids):Prompt 全部 mask,一个不剩。
  2. 截断时同时截 labels:否则 input_ids 和 labels 长度不一致,CrossEntropyLoss 直接报错。
  3. Padding 也填 -100:Padding token(通常是 0 或 <pad>)不是真实数据,模型不应该学它。如果忘了这行,模型会花大量算力去预测填充符。

3.2 Loss 计算:四个操作一气呵成

def compute_sft_loss(logits: torch.Tensor, labels: torch.Tensor):
    # logits:    [B, seq_len, vocab_size]
    # labels:    [B, seq_len]

    # Step 1: Shift 错位
    shift_logits = logits[..., :-1, :].contiguous()   # [B, seq_len-1, V]
    shift_labels = labels[..., 1:].contiguous()        # [B, seq_len-1]

    # Step 2: 展平 + CrossEntropyLoss
    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))  # [B*S, V]
    shift_labels = shift_labels.view(-1)                          # [B*S]
    loss = loss_fct(shift_logits, shift_labels)

    return loss

每一步的解释:

  • logits[..., :-1, :]:切掉最后一个预测位置。因为序列的最后一个 token 没有"下一个 token"作为监督信号。
  • labels[..., 1:]:切掉第一个标签 token。因为第一个位置的 token 没有对应的"上一个预测"。
  • .contiguous():切片操作可能导致 tensor 在内存中不连续,后续 .view() 要求连续内存。没有这一行,PyTorch 会报 "view size is not compatible" 的错——这是新手最常见的报错之一。
  • ignore_index=-100:PyTorch 底层会跳过所有 labels 中值为 -100 的位置,不计入前向 Loss 也不参与反向梯度。这比手动构造 mask 矩阵高效得多,因为是在 CUDA kernel 层面做的跳过。

4. 一张表看清 SFT 的标签流向

以一个具体例子走一遍完整流程:

prompt_ids  = [10, 20, 30]           # 3 个 token 的提问
response_ids = [40, 50, 60, 70]      # 4 个 token 的回答
max_len = 8, pad_id = 0
步骤input_idslabels说明
拼接[10,20,30,40,50,60,70][-100,-100,-100,40,50,60,70]Prompt 全 mask
填充 (+1)[10,20,30,40,50,60,70,0][-100,-100,-100,40,50,60,70,-100]Pad 也 mask
Shift logits[:-1] → 7 个位置丢掉最后一个预测
Shift labels[1:] → 7 个位置丢掉第一个标签
有效位置共 7 个位置,只有 4 个有效Prompt(3) + Pad(1) 被 mask

最终 CrossEntropyLoss 只计算 4 个 Response token 的 Loss,Prompt 和 Padding 完全不影响训练。


5. 与工业实现对照

框架实现方式差异
HuggingFace TrainerDataCollatorForCompletionOnlyLM 自动查找 Response 起始位置并设 -100自动化程度高,但不透明
LLaMA-Factorypreprocess.py 中手动构造 labels和本节实现几乎一模一样
DeepSpeed-Chatdata_utils.py 中使用相同的 mask 逻辑大规模分布式训练的首选方案

工业界的共识就是这三步:拼接 → mask prompt → shift 对齐。没有银弹,没有黑魔法。


6. 踩坑记录

6.1 忘了 shift 导致 Loss 偏高

这是 SFT 训练中最常见的 bug。如果不做 labels[..., 1:],相当于让 logits[t] 预测 labels[t]——模型的任务从"预测下一个 token"变成了"复制当前 token"。模型当然学得会(复制嘛),但推理时生成出来的东西就是一坨随机噪声,因为推理时根本不知道当前 token 是什么。

检查方法:打印一个 batch 的 shift_labels 中非 -100 的 token 数量,和 Response 长度对比。

6.2 Padding 的 labels 没设 -100

模型输出的 token 分布中 padding token(通常是 0)的频率异常高。因为模型在努力学"在句子末尾输出 "——这显然不是你想要的。

6.3 截断剪掉了所有 Response

max_len 设得太小,截断后 Prompt 还在但 Response 全被剪掉了。此时 labels 全是 -100,Loss = 0,但模型什么都没有学到。训练曲线看着很美(Loss=0),实际上是空跑。

检查方法:确保 max_len > len(prompt) + min_response_len


7. 延伸思考

  • Chat Template 和本节的关系:真实 SFT 数据还需要插入特殊 token(LLaMA 的 <|begin_of_text|><|start_header_id|> 等)。这些特殊 token 也应该 mask 掉,因为它们属于"格式"而非"内容"。但 chat template 的处理是数据预处理层的事,本节的核心 label mask 逻辑不变。
  • Packing 技术:为提高 GPU 利用率,可以把多条短数据拼成一条长序列(packing)。此时 label mask 变得复杂:不仅要 mask prompt,还要 mask 不同样本之间的 padding 和 cross-sample attention。
  • 与 RLHF 的关系:SFT 是整个 RLHF 管线(SFT → Reward Model → PPO)的第一步。本节的数据构造逻辑直接沿用给后续的 PPO 阶段(PPO 中也需要 mask prompt 来避免策略在无关 token 上乱改)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值