本章定位:从表示学习(Ch1–Ch4)跨入生成模型对齐(Ch5–Ch8)的第一站。SFT 是把”只会预测下一个词”的 Base Model 转变为”能听懂指令”的 Assistant 的关键步骤;LoRA/QLoRA/DoRA 是当今所有微调任务的工程标配。
承上:Ch1 §6 的交叉熵 = SFT 的损失函数本身。
启下:Ch6 起的所有 RL 对齐方法都以 SFT 模型为初始化。
§A 数学原理
1. SFT 的本质:最大似然估计 (MLE)
给定数据集 $D_{\text{SFT}} = {(x_i, y_i)}$,其中 $x$ 是 prompt、$y$ 是 response,目标是最大化条件概率 $P_\theta(y \mid x)$:
由于 $y = (y_1, y_2, \dots, y_T)$ 是 token 序列,自回归模型把它分解为:
取负对数得到 SFT 的损失:
这就是 Ch1 §6 的交叉熵:让模型分布逼近”标注员的分布”。
2. 为什么只对 Response 计算 Loss?
SFT 时,整条序列 [Prompt][Response] 都喂给模型,但Loss Mask 让 prompt 部分不参与 loss 计算:
原因:
- 我们不希望模型学习”用户怎么提问”(prompt 是外部输入)
- 只要求模型学习”给定 prompt 如何生成正确 response”
实现方式:构造 labels 张量,prompt 位置设为 -100(PyTorch 默认 ignore_index,自动跳过)。
3. LoRA:低秩适配的数学
LoRA(Low-Rank Adaptation, Microsoft 2021)的核心假设:预训练模型在下游任务上的更新量 $\Delta W$ 是低秩的。
3.1 参数化方式
原始权重 $W_0 \in \mathbb{R}^{d \times d}$ 冻结,引入两个低秩矩阵:
- $A \in \mathbb{R}^{r \times d}$,初始化为高斯分布
- $B \in \mathbb{R}^{d \times r}$,初始化为零矩阵
更新形式:
前向传播:
3.2 为什么 $B$ 初始化为零?
如果 $A, B$ 都随机初始化,训练初期 $BA \neq 0$,模型从一开始就偏离了预训练。$B = 0$ 保证 $BAx = 0$,模型初始状态 = 预训练状态,训练从一个已知的好起点出发。
3.3 参数量节省
原始权重参数量:$d \times d = d^2$
LoRA 参数量:$d \times r + r \times d = 2dr$
设 $d = 4096, r = 8$:
LoRA 把可训练参数压缩到原始的 1% 以下。
3.4 缩放因子 $\alpha$
实际的 LoRA 公式还有一个缩放:
$\alpha$ 通常设为 $r$ 的 2–4 倍(如 $r=8, \alpha=32$)。作用:当增大 $r$ 时,$\frac{\alpha}{r}$ 自动缩小,避免重新调整学习率。
4. QLoRA:4-bit 量化 + LoRA
QLoRA(Dettmers 2023)让 65B 模型可以在单张 48GB 显卡上微调:
| 技术 | 作用 |
|---|---|
| 4-bit NF4 量化 | Base model 权重压缩到 4-bit(NormalFloat 4,对正态分布权重最优) |
| Double Quantization | 量化常数本身再量化,节省 ~0.4 bit/参数 |
| Paged Optimizer | 用 NVIDIA 统一内存,optimizer state 临时换出到 CPU |
| LoRA on top | 量化的 base model 不动,只训练 16-bit 的 LoRA 矩阵 |
核心数学:base model 用 4-bit 存储但前向传播时反量化为 BF16 计算,loss 和梯度都是 BF16,所以训练精度几乎不损失。
5. DoRA:Weight Decomposed LoRA
DoRA(NVIDIA 2024)把权重分解为”幅值 + 方向”:
其中 $m$ 是幅值(每列一个标量),$V/|V|$ 是方向。LoRA 只更新方向部分 $V$,幅值 $m$ 全量训练。
直觉:预训练模型已经学到了大致的”权重幅值分布”,下游任务主要在调整”方向”。DoRA 在多个任务上比 LoRA 高 1–2 个点。
§B 模型结构(PyTorch 实现)
B.1 SFT 数据处理:构造 labels = -100 的 mask
这是 SFT 工程上最容易写错的地方。
1 | def build_sft_labels(input_ids, prompt_length): |
TRL 库的现成实现:1
2
3
4
5
6from trl import DataCollatorForCompletionOnlyLM
collator = DataCollatorForCompletionOnlyLM(
response_template="### Response:", # 标志 response 开始
tokenizer=tokenizer,
)
B.2 LoRA 的 PyTorch 手写实现
1 | import torch |
B.3 PEFT 库的工业级使用
实际工程中用 PEFT 库就够了:
1 | from peft import LoraConfig, get_peft_model, TaskType |
target_modules 的演进:
- 早期 LoRA 论文:只加在 $W_Q, W_V$
- 现代趋势:全层加 LoRA(包括 FFN 和输出投影),效果显著提升
B.4 推理时合并 LoRA 权重
LoRA 训练完后,推理时可以把 $BA$ 合并回 $W_0$,零额外开销:
1 | # PEFT 一键合并 |
合并后模型与原始模型结构完全一致,推理时无任何性能损失——这是 LoRA 比 Adapter(需推理时额外计算)优越的关键。
§C 训练与推理
C.1 训练流程:完整的 SFT 训练循环
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer |
C.2 显存对比:Full FT vs LoRA vs QLoRA
以 Llama-3 8B 为例(粗略估算):
| 方式 | 模型权重 | 梯度 | Adam state | 总计 (~) |
|---|---|---|---|---|
| Full FT (BF16) | 16 GB | 16 GB | 32 GB (FP32) | 64+ GB |
| LoRA (BF16) | 16 GB | 0.08 GB | 0.16 GB | ~17 GB |
| QLoRA (4-bit + LoRA) | 4 GB | 0.08 GB | 0.16 GB | ~5 GB |
QLoRA 让单张 24GB 卡能微调 65B 模型,单张 48GB 卡能微调 70B 模型。
C.3 SFT 解决了什么 / 还差什么
解决了:
- ✅ 格式对齐:模型学会”问答”形式
- ✅ 能力激活:预训练知识被”提取”出来用于回答
还差什么:
- ❌ 多解性:写代码、写诗有多种正确答案,SFT 强迫模型只学某一种
- ❌ 错误积累:本质是模仿学习,标注里的错误会被学得很扎实
- ❌ 没学到”什么是不好的”:SFT 只展示正例,没有负例对比
→ 这就是为什么需要后续的 RLHF/DPO(Ch6, Ch7)。
C.4 SFT 后的推理:解码策略
SFT 训出的模型在推理时和 base model 形式上没有区别,都是自回归生成。但因为模型分布变得更”尖锐”(向标注者的语气和格式聚集),常用的解码策略需要调整:
| 策略 | 公式/做法 | 适用场景 |
|---|---|---|
| Greedy | $\arg\max_y P(y \mid \cdot)$ | 确定性任务(代码、数学) |
| Top-k | 只从概率最高的 $k$ 个 token 中采样 | 一般生成 |
| Top-p (nucleus) | 从累积概率 $\geq p$ 的最小集合中采样 | 创意生成、对话 |
| Temperature $T$ | $\text{softmax}(\text{logits} / T)$,$T < 1$ 更尖锐 | 控制确定性 |
| Repetition Penalty | 已生成 token 的 logits 减去惩罚 | 避免循环输出 |
常见组合:对话默认
temperature=0.7, top_p=0.9;代码任务用temperature=0.0(即 greedy)。
§D 章末速查
D.1 LoRA 的关键 Q&A
Q1:LoRA 训练多少参数?
- 取决于 r 和 target_modules。典型 7B 模型:
- 只加 q,v:~4M(0.05%)
- 加全 7 个 module:~40M(0.5%)
Q2:LoRA 学习率多大?
- 通常 2e-4 到 5e-4,比全量微调(2e-5)大 10 倍。
- 因为只训练少量参数,需要更大的步长。
Q3:LoRA 推理时有额外开销吗?
- 训练时:有(额外 BA 计算)
- 推理时:合并后零开销,与原模型完全一致
Q4:LoRA 合并后能再继续训练吗?
- 可以。但通常做法是保存多个 LoRA adapter,按需加载,不合并。这样一个 base model 可以服务多个任务。
D.2 SFT 数据规模与质量
| 数据规模 | 适用场景 | 代表 |
|---|---|---|
| 1k–10k 高质量 | 概念验证、风格定制 | LIMA(”少即是多”) |
| 10k–100k | 通用 SFT | InstructGPT |
| 100k–1M | 全方位强化 | Llama 3 等 |
| >1M | 边际收益递减 | 商业大模型 |
LIMA 论文核心结论:1000 条精心筛选的 SFT 数据,能让 65B 模型达到 GPT-4 的对话质量。质量远比数量重要。
承上启下
SFT 让模型学会了”听话”,但只学到了”标注者怎么写”,无法理解”什么是更好的”。下一章 Ch6 引入奖励模型(RM)和 PPO,让模型从人类偏好中学习——这是从 GPT-3.5 到 ChatGPT 的关键一跳。
Ch6 中的 SFT 模型在三个地方用到:
- Policy 初始化:SFT 模型直接作为 PPO 的初始 Policy
- Reference Model:SFT 模型的冻结副本(呼应 Ch4 §D 的 BYOL Target)
- RM 初始化:RM 也用 SFT 模型改造(去掉 LM Head,加 Scalar Head)