1.FlashAttention 整体结构
FlashAttention 是由斯坦福大学 Tri Dao 等人在 2022 年提出的一种精确(非近似)的注意力计算算法。它通过重新设计注意力的计算方式,大幅降低了显存占用并加速了训练与推理,现已成为现代大模型(如 GPT、LLaMA 等)的标配。
一、问题背景:标准 Attention 的瓶颈
在 Transformer 中,自注意力的核心计算是:
其中 $Q, K, V \in \mathbb{R}^{N \times d}$,$N$ 是序列长度,$d$ 是每个头的维度。
标准实现的做法(PyTorch 默认的 naive 实现):
- 计算 $S = QK^T$,得到一个 $N \times N$ 的矩阵,写回 HBM(显存)
- 计算 $P = \text{softmax}(S)$,再次写回 HBM
- 计算 $O = PV$,写回 HBM
两个核心问题:
- 显存占用 $O(N^2)$:当 $N=8192$ 时,单个注意力矩阵就要 256MB(fp32),多头多层叠加显存爆炸
- 速度瓶颈不是算力,而是显存带宽:GPU 的算力增长远快于显存带宽,注意力计算是典型的 memory-bound(访存受限)问题,大量时间花在 HBM 读写上
二、GPU 内存层级的关键认知
理解 FlashAttention 必须先理解 GPU 的内存结构:
| 层级 | 容量 | 带宽 | 速度 |
|---|---|---|---|
| HBM(全局显存) | 40–80 GB | ~1.5–3 TB/s | 慢 |
| SRAM(片上共享内存) | ~20 MB(A100 全部 SM 合计) | ~19 TB/s | 快约 10 倍 |
标准 attention 反复在 HBM 中读写大矩阵,SRAM 几乎没被利用。FlashAttention 的核心思路就是:尽可能把计算留在 SRAM 里,避免往 HBM 写中间结果。
三、核心思想:Tiling + Online Softmax + Recomputation
1. Tiling(分块计算)
把 $Q, K, V$ 沿序列维度切成小块(block),每次只把一小块加载到 SRAM 中计算,这样 $N \times N$ 的大矩阵永远不会被完整地实例化。
假设将 $Q$ 切成 $T_r$ 块,$K, V$ 切成 $T_c$ 块,算法变成一个双层循环: 外层遍历 $Q$ 的块,内层遍历 $K, V$ 的块。
2. Online Softmax(增量式 softmax)
最大的技术难点是: softmax 需要看到整行所有元素才能归一化(要算分母 $\sum e^{x_i}$),而分块计算时一次只能看到一部分。FlashAttention 采用了 Milakov & Gimelshein 提出的 online softmax 技巧。
为了数值稳定,softmax 通常减去最大值:
当新的一块数据进来,当前块最大值是 $m^{new}$,则可以这样更新:
输出 $O$ 也可以用类似的 rescaling 进行增量更新:
这样每处理完一个 $K, V$ 块,就用缩放因子修正之前累积的结果,数学上与标准 softmax 完全等价,精度没有任何损失。
3. Recomputation(反向传播时重算)
在反向传播中,原本需要保存 $N \times N$ 的注意力矩阵 $P$ 来算梯度。FlashAttention 只保存 softmax 的统计量(每行的最大值 $m$ 和归一化因子 $\ell$,各 $O(N)$),反向时重新计算 $S$ 和 $P$。
看起来增加了计算量,但由于 attention 是 memory-bound 的,省下来的 HBM 访问时间远多于重算的时间,整体反而更快。
四、复杂度分析
设 $M$ 是 SRAM 大小:
| 指标 | 标准 Attention | FlashAttention |
|---|---|---|
| 计算量(FLOPs) | $O(N^2 d)$ | $O(N^2 d)$ |
| HBM 访问量 | $O(Nd + N^2)$ | $O(N^2 d^2 / M)$ |
| 显存占用 | $O(N^2)$ | $O(N)$ |
计算量没变,但 HBM 访问量在 $d \ll M$ 的常见情况下大幅降低($d$ 一般是 64 或 128,$M$ 在 A100 上约 100KB 级别),这是加速的根本来源。
2.Online Softmax
Online Softmax 是 FlashAttention 能成立的数学基石。它由 NVIDIA 的 Milakov & Gimelshein 在 2018 年的论文 “Online normalizer calculation for softmax” 中提出,核心思想是:在只看过部分数据的情况下,增量地计算数值稳定的 softmax,并在新数据到来时做修正。
下面我会从标准 softmax 的问题开始,一步步推导到 online 版本,并解释它在 FlashAttention 中的角色。
一、标准 Softmax 与数值稳定性问题
1.1 朴素定义
给定向量 $x = (x_1, x_2, \dots, x_N)$,softmax 定义为:
1.2 数值溢出问题
如果某个 $x_i$ 较大(比如 1000),$e^{x_i}$ 会溢出变成 inf。即使在 fp16 下,$x_i > 11$ 就会溢出($e^{11} \approx 59874$,fp16 的上限约 65504)。
1.3 Safe Softmax(3-pass 版本)
解决方法是减去最大值:
这个变换数学上完全等价(分子分母同时乘以 $e^{-m}$),但把最大指数限制在 0,保证 $e^{x_i - m} \in (0, 1]$,不会溢出。
算法变成 3 遍扫描:
1 | Pass 1: 求最大值 m |
问题:需要 3 次读取整个 $x$ 向量,对 GPU 这种 memory-bound 场景非常不友好。如果 $x$ 太大放不进 SRAM,就得反复从 HBM 读 3 次。
二、Online Softmax 的核心想法
目标:把 Pass 1 和 Pass 2 合并成一次扫描,也就是边扫描边同时维护 $m$ 和 $\ell$。
难点在于:$\ell$ 的定义依赖 $m$,而 $m$ 要扫完整个数组才能确定。扫到一半时的”局部最大值”可能比真正的最大值小,之前算出来的 $\ell$ 就是错的。
关键洞察:如果有办法在 $m$ 发生变化时修正 $\ell$,就可以一边扫一边算。
2.1 推导修正公式
假设已经处理了前 $i$ 个元素,维护着:
- $m_i$ = 前 $i$ 个元素中的最大值
- $\ell_i = \sum_{j=1}^{i} e^{x_j - m_i}$ = 基于当前最大值的归一化因子
现在来了第 $i+1$ 个元素 $x_{i+1}$,新的最大值是:
要得到新的 $\ell_{i+1} = \sum_{j=1}^{i+1} e^{x_j - m_{i+1}}$,观察:
对旧部分做一个简单变换:
于是得到更新公式:
几何直觉: 每次最大值增大(即 $m_{i+1} > m_i$),之前累积的 $\ell_i$ 是基于”旧基准” $m_i$ 的,现在基准变大了,所有旧项都得乘以 $e^{m_i - m_{i+1}} < 1$ 进行缩小(rescale)。如果最大值没变($m_{i+1} = m_i$),缩放因子是 $e^0 = 1$,什么都不用调整。
2.2 2-pass 算法
1 | Pass 1: 同时计算 m 和 ℓ |
从 3 遍扫描降到了 2 遍,而且数值稳定性与 safe softmax 完全相同(始终保持减去当前最大值)。
三、分块(Block-wise)版本:FlashAttention 的真正需求
单元素的更新公式只是开胃菜。FlashAttention 需要按块处理——一次性算完一个块的局部 $m$ 和 $\ell$,再和已有的累积量合并。
3.1 两块合并公式
设已经处理完块 A,维护着 $(m_A, \ell_A)$; 现在处理块 B,独立算出 $(m_B, \ell_B)$,其中:
合并后的全局量 $(m, \ell)$ 应为:
这就是 FlashAttention 里那个看起来眼熟的公式。推导完全同上:两块都要把基准从自己的局部最大值改到全局最大值,各自乘以对应的 $e^{m_{\text{local}} - m_{\text{global}}}$ 缩放因子。
注意这个合并操作是满足结合律的——合并 A、B 再合并 C,和先合并 B、C 再和 A 合并,结果一致。这使得它能很自然地并行化(比如 tree reduction)。
四、扩展:同时在线更新输出 $O$
在普通 softmax 里,Pass 2 要重新扫一遍算 $y_i$。但在 attention 里,我们最终要的不是 softmax 本身,而是 $O = PV$,其中 $P = \text{softmax}(S)$。FlashAttention 需要把 $O$ 也一起在线累积,这样整个算法变成 1-pass。
4.1 定义未归一化的输出
定义”未归一化”的输出(只做了减最大值但没除以 $\ell$):
处理块 B 时同理:
合并后全局未归一化输出应为 $\tilde{O} = \sum_{j \in A \cup B} e^{x_j - m} V_j$,同样需要基准变换:
最后一步再除以 $\ell$ 得到真正的 $O = \tilde{O} / \ell$。
4.2 伪代码(FlashAttention 核心)
1 | m = -∞, ℓ = 0, O = 0 # O 此处表示未归一化的 Õ |
这样,整个 attention 只需一次扫描 $K, V$,完全不需要实例化 $N \times N$ 的注意力矩阵。FlashAttention-2 的一个小优化就是把每一步的 / ℓ 推迟到最后,避免反复做除法。
五、为什么数学上精确?
这点很多人初学时会怀疑:”边扫边更新,最后结果真的和一次性算一样吗?”
答案是完全一样,因为每次合并时:
- $m$ 始终是当前已见元素的真实最大值
- $\ell$ 通过缩放因子修正后,始终等于 $\sum_{j \in \text{已见}} e^{x_j - m_{\text{当前}}}$
- $\tilde{O}$ 同理
扫描结束时,$m$ 等于全局最大值,$\ell$ 等于全局归一化因子,$\tilde{O} / \ell$ 等于标准 attention 的输出。这和”提前知道 $m$ 再一口气算”在代数上完全等价——没有任何近似。
六、一个数值例子
以 $x = [1, 3, 2, 5]$ 为例,手工走一遍:
标准做法: $m = 5$,$\ell = e^{-4} + e^{-2} + e^{-3} + e^0 \approx 0.0183 + 0.1353 + 0.0498 + 1 = 1.2034$
Online 做法:
| 步骤 | 新元素 | $m_{\text{new}}$ | 缩放因子 $e^{m_{\text{old}} - m_{\text{new}}}$ | $\ell$ 更新 |
|---|---|---|---|---|
| 初始 | — | $-\infty$ | — | $\ell = 0$ |
| 步 1 | 1 | 1 | $e^{-\infty - 1} = 0$ | $\ell = 0 \cdot 0 + e^0 = 1$ |
| 步 2 | 3 | 3 | $e^{1 - 3} = 0.1353$ | $\ell = 1 \cdot 0.1353 + e^0 = 1.1353$ |
| 步 3 | 2 | 3 | $e^{3 - 3} = 1$ | $\ell = 1.1353 \cdot 1 + e^{-1} = 1.1353 + 0.3679 = 1.5032$ |
| 步 4 | 5 | 5 | $e^{3 - 5} = 0.1353$ | $\ell = 1.5032 \cdot 0.1353 + e^0 = 0.2034 + 1 = 1.2034$ ✓ |
最后得到 $\ell = 1.2034$,与标准做法完全一致。
七、工程意义
Online softmax 看似只是一个小技巧,实际上解锁了一整类算法:
- FlashAttention — 让 attention 从 $O(N^2)$ 显存降到 $O(N)$
- Ring Attention — 把序列分片到多 GPU,每个 GPU 持一部分 $K, V$,通过环形通信传递部分结果,用 online softmax 合并。支持百万级 token 序列
- 流式推理 / Prefix Caching — 先算一部分 prefix 的 softmax 状态,后续 token 来时增量更新
- 分布式 softmax — 跨设备的 softmax 归一化
它的思想可以抽象为:任何 “reduce + 归一化”的操作,只要能找到带修正因子的合并律,就能分块并行计算。这一点也和 log-sum-exp 的并行化密切相关(实际上 $\log \ell + m$ 就是 log-sum-exp,online softmax 等价于 log-sum-exp 的在线计算)。
3.Recomputation(重计算)详解
Recomputation 是 FlashAttention 反向传播的核心技巧。它的本质是一个用计算换显存的权衡——前向时故意不保存中间结果,反向时重新算一遍。这个思想本身不是 FlashAttention 发明的(在深度学习中叫 gradient checkpointing,2016 年就有了),但 FlashAttention 把它用到了极致,并且在这里不但不慢,反而更快。
一、为什么反向传播需要中间结果?
先回顾反向传播的基本原理:链式法则需要”前向时的中间激活值”。
举个最简单的例子,$y = \sigma(Wx)$,反向求 $\frac{\partial L}{\partial W}$:
这里需要 $\sigma’(Wx)$,而 $\sigma’$ 依赖前向的中间量 $Wx$。所以默认情况下,所有中间激活值都得保存到反向用完为止。
二、标准 Attention 反向需要什么?
前向是:
反向时给定 $\frac{\partial L}{\partial O} = dO$,要求 $dQ, dK, dV$。推导一下(详细推导步骤请看附录A“反向传播公式推导”):
需要的中间量:
- $P$(用于算 $dV$ 和 $dS$)—— $N \times N$ 矩阵
- $S$ 本身不一定要,因为 $P$ 已经隐含了 $S$ 的信息
所以标准实现要保存整个 $N \times N$ 的 $P$ 矩阵,反向期间一直占着显存。这又把我们带回了 $O(N^2)$。
三、FlashAttention 的选择:只保存最小必要信息
FlashAttention 前向结束后,只保留两样东西用于反向:
| 保留的量 | 形状 | 含义 |
|---|---|---|
| $O$ | $N \times d$ | 前向的输出(本来就要保留) |
| $L$ | $N \times 1$ | 每行的 “logsumexp” 统计量 |
其中 $L$ 是每行的:
$m_i$ 和 $\ell_i$ 就是上一节 online softmax 里的最大值和归一化因子。两个标量被合并成了一个(后面会解释为什么这样更方便)。
总共额外需要的显存:$O(N) + O(Nd) = O(N)$,比原来省了 $N \times N$ 矩阵那么多。
四、反向时怎么重算?
反向传播依然是分块进行的,和前向一样按 tile 扫描。核心观察是:
给定 $Q_i, K_j, L_i$,可以完全重建某个小块的 $P_{ij}$。
4.1 重建 $P$ 的公式
对于块 $(i, j)$:
这里的关键点:$L_i = m_i + \log \ell_i$,所以
正好是标准 softmax 的结果。把 $m$ 和 $\ell$ 合并成 $L$ 的好处就在这:一次减一次 exp 就拿到归一化的 $P$,不用再做除法。
4.2 反向算法(简化版)
1 | 输入: Q, K, V, O, dO, L, 块大小 B_r, B_c |
注意里面那个 $D_i = \text{rowsum}(dO_i \odot O_i)$ 的小技巧——softmax 反向本来需要 $\text{rowsum}(dP \odot P)$,但在 attention 里有个恒等式:
这意味着这个量可以直接用 $O$ 和 $dO$ 算出来,不需要完整的 $P$ 和 $dP$。这个 $D_i$ 本来可以预先一次性算好(形状只有 $N \times 1$),让反向循环更干净。
五、计算量的账:真的能接受吗?
重算显然多了 FLOPs,算算具体多多少。
前向计算量:
- $S = QK^T$: $2N^2 d$ FLOPs
- softmax: $O(N^2)$
- $O = PV$: $2N^2 d$ FLOPs
总共约 $4N^2 d$。
反向计算量(标准实现,不重算):
- $dV = P^T dO$: $2N^2 d$
- $dP = dO V^T$: $2N^2 d$
- softmax 反向: $O(N^2)$
- $dQ = dS \cdot K$: $2N^2 d$
- $dK = dS^T Q$: $2N^2 d$
总共约 $8N^2 d$。
FlashAttention 反向(带重算):
- 多算一遍 $S = QK^T$: $+2N^2 d$
- 其他同上
总共约 $10N^2 d$。
增加了大约 25% 的 FLOPs(从 $8N^2d$ 到 $10N^2d$)。
六、为什么增加 25% FLOPs 反而更快?
这是 FlashAttention 最反直觉的地方:明明多算了,为什么还更快?
答案是之前反复强调的一点:attention 是 memory-bound,不是 compute-bound。
6.1 现代 GPU 的算力 / 带宽失衡
A100 的数字:
- FP16 算力:312 TFLOPS
- HBM 带宽:1.5 TB/s
- 算力/带宽比约 200 FLOPs/byte
这意味着:每从 HBM 读一个字节,GPU 最好能做 200 次浮点运算才”划算”。如果做不到(算术强度低于这个比值),GPU 就在等数据——算力闲置。
Attention 的瓶颈就在于 $N^2$ 的中间矩阵读写,算术强度很低。
6.2 标准反向 vs FlashAttention 反向的实际耗时
| 项目 | 标准反向 | FlashAttention 反向 |
|---|---|---|
| FLOPs | $8N^2d$ | $10N^2d$(多 25%) |
| HBM 读写 | $O(N^2)$(读写 $P, dP, dS$ 等大矩阵) | $O(N^2 d^2 / M)$(只读写 $Q, K, V, O, dO, dQ, dK, dV$ 和标量 $L, D$) |
| 实际瓶颈 | HBM 带宽 | 算力 |
因为 HBM 读写从 $O(N^2)$ 降到接近 $O(Nd)$(忽略 block 层面的重复访问),省掉的 HBM 访问时间 远多于 多算 25% FLOPs 的时间。
举个粗略的估算:若 $N = 4096, d = 64$,标准反向要读写的 $P, dP$ 等大约是 $4 \times N^2 \times 2 = 128$ MB,在 1.5 TB/s 的带宽下约耗时 85 μs。而多算的 $2N^2 d$ FLOPs 在 312 TFLOPS 下只需 0.2 μs。两个数量级的差距。
七、和通用 Gradient Checkpointing 的关系
通用的 gradient checkpointing(Chen et al. 2016)是:
- 前向只在某些”checkpoint”保存激活
- 反向时从最近的 checkpoint 重跑一段前向
FlashAttention 的 recomputation 可以看作是针对 attention 的专用 checkpointing:
- checkpoint 的是什么?$L$(softmax 统计量)和 $O$
- 重算的是什么?$S$ 和 $P$ 这两个大矩阵
但 FlashAttention 的做法比通用 checkpointing 更精细:
- 融合 kernel:重算和梯度计算在同一个 CUDA kernel 里完成,中间值只活在 SRAM,不走 HBM
- 按需重算,粒度是块:不是重跑整段前向,而是反向循环遍历到哪个 block 就重算哪个 block
- 倒过来快:通用 checkpointing 只是节省显存、会更慢;FlashAttention 因为跳过了 HBM 读写,同时更省显存并且更快
这是一个理论上罕见的”两全其美”——它得以成立,完全依赖于现代 GPU 算力 / 带宽的失衡。
八、总结
Recomputation 在 FlashAttention 中解决的核心问题是:反向传播怎么不保存 $N \times N$ 的 $P$ 矩阵。
关键要点:
- 只存 $L = m + \log \ell$ 和 $O$(都是 $O(N)$ 或 $O(Nd)$),不存 $S, P$
- 反向时用 $Q_i, K_j, L_i$ 在 SRAM 里重算 $S_{ij}$ 和 $P_{ij}$,用完即丢
- 利用 $\text{rowsum}(dP \odot P) = \text{rowsum}(dO \odot O)$ 这个恒等式避免需要完整的 $dP$
- 多了约 25% 的 FLOPs,但因为 attention 是 memory-bound,省下的 HBM 读写时间远超重算的开销
更高层的启示是:在现代 GPU 上,“少存一点、多算一点” 经常是更好的选择。算力在持续增长,带宽却增长缓慢,这个剪刀差会让 recomputation 这种技术越来越普遍。
4. 附录A
结论一:若 $Y = AB$,则 $dA = dY \cdot B^T, \quad dB = A^T \cdot dY $
结论 1:若 $Y = AB$,则
这是矩阵求导里最基础也最重要的结论。我之前给的推导比较简略,这里从头一步步讲清楚,包括直觉、严格推导、和几种不同的验证方法。
一、先把问题说清楚
设定:
- $A \in \mathbb{R}^{m \times k}$
- $B \in \mathbb{R}^{k \times n}$
- $Y = AB \in \mathbb{R}^{m \times n}$
- $L$ 是某个标量损失函数,依赖 $Y$(通过 $Y$ 间接依赖 $A$ 和 $B$)
已知:$dY := \frac{\partial L}{\partial Y} \in \mathbb{R}^{m \times n}$(上游传下来的梯度,形状和 $Y$ 一样)
目标:求 $dA = \frac{\partial L}{\partial A} \in \mathbb{R}^{m \times k}$ 和 $dB = \frac{\partial L}{\partial B} \in \mathbb{R}^{k \times n}$
关键约定:梯度 $\frac{\partial L}{\partial X}$ 的形状总是和 $X$ 本身一样。也就是说,$dA$ 的第 $(i,k)$ 元素表示 $\frac{\partial L}{\partial A_{ik}}$。这个约定叫 "denominator layout",是深度学习的通用惯例。
二、标量形式的矩阵乘法
先把矩阵乘法写成逐元素的形式:
即 $Y$ 的第 $(i,j)$ 个元素是 $A$ 的第 $i$ 行和 $B$ 的第 $j$ 列的点积。
三、推导 $dA$
3.1 应用多变量链式法则
$L$ 依赖 $A$ 是通过所有 $Y_{ij}$ 实现的,所以:
这里 $i', j$ 遍历 $Y$ 的所有元素,因为改变 $A_{ik}$ 可能影响 $Y$ 的每一个元素,所以要全部求和。
3.2 计算 $\frac{\partial Y_{i'j}}{\partial A_{ik}}$
从 $Y_{i'j} = \sum_p A_{i'p} B_{pj}$ 出发。
$A_{ik}$ 是否出现在 $Y_{i'j}$ 的表达式里? 分情况讨论:
- 如果 $i' \neq i$:表达式里是 $A_{i'p}$ 的各种形式,第一个下标是 $i'$,而我们关心的 $A_{ik}$ 第一个下标是 $i$,不同。根本不出现,所以偏导为 0。
- 如果 $i' = i$:$Y_{ij} = \sum_p A_{ip} B_{pj} = A_{i1}B_{1j} + A_{i2}B_{2j} + \dots + A_{ik}B_{kj} + \dots$。其中只有 $p = k$ 那一项含 $A_{ik}$,系数是 $B_{kj}$。
所以:
3.3 代回链式法则
把这个结果代入:
由于指示函数 $\mathbb{1}[i' = i]$ 只在 $i' = i$ 时为 1,外层对 $i'$ 的求和塌缩,只剩下 $i' = i$ 这一项:
3.4 识别这是一个矩阵乘法
看这个式子的结构:$\sum_j dY_{ij} B_{kj}$。
标准矩阵乘法 $(XZ)_{ik} = \sum_j X_{ij} Z_{jk}$,需要内部下标匹配。我们现在的 $B_{kj}$ 内部下标是 $j$,但 $k$ 在前面,所以要把 $B$ 转置一下:$B^T_{jk} = B_{kj}$。
于是:
写成矩阵形式:
形状验证:$dY \in \mathbb{R}^{m \times n}$,$B^T \in \mathbb{R}^{n \times k}$,乘积形状是 $m \times k$,与 $A$ 一致 ✓
四、推导 $dB$
完全类似的过程。
4.1 链式法则
4.2 计算 $\frac{\partial Y_{ij'}}{\partial B_{kj}}$
从 $Y_{ij'} = \sum_p A_{ip} B_{pj'}$ 出发。
- 如果 $j' \neq j$:$B_{pj'}$ 第二个下标是 $j'$,不含 $B_{kj}$,偏导为 0。
- 如果 $j' = j$:$Y_{ij} = \sum_p A_{ip} B_{pj}$,其中 $p = k$ 的项是 $A_{ik} B_{kj}$,其他项无关。
所以:
4.3 代回
4.4 识别为矩阵乘法
现在要组成一个结果形状为 $k \times n$ 的矩阵,下标是 $(k,j)$。式子里是 $\sum_i A_{ik} dY_{ij}$。
$A_{ik} = A^T_{ki}$,所以:
写成矩阵形式:
形状验证:$A^T \in \mathbb{R}^{k \times m}$,$dY \in \mathbb{R}^{m \times n}$,乘积形状 $k \times n$,与 $B$ 一致 ✓
五、直觉理解:为什么一定是转置?
光有严格推导还不够,这里讲两种直觉,帮助你以后不看推导也能秒写出来。
5.1 形状匹配法(最快)
已知 $Y = AB$,形状 $(m,n) = (m,k) \cdot (k,n)$。
想求 $dA$,形状必须是 $(m,k)$。能用到的两块积木是 $dY$(形状 $m \times n$)和 $B$(形状 $k \times n$)。
要得到 $m \times k$,唯一合理的方式是 $(m \times n) \cdot (n \times k) = (m \times k)$,所以 $B$ 必须转置。
同理求 $dB$,形状必须 $(k,n)$,用 $A$ 和 $dY$ 拼:$(k \times m) \cdot (m \times n) = (k \times n)$,所以 $A$ 要转置。
这个技巧几乎万能。很多人做推导就是靠形状对齐。
5.2 "什么乘什么得到 Y" 视角
看 $Y = AB$:
- $A$ 在左边,所以它的梯度也要"从左边构造出来",剩下的 $B$ 就得挪到右边,但为了形状对,$B$ 要转置 → $dA = dY \cdot B^T$
- $B$ 在右边,所以它的梯度也要"从右边构造出来",$A$ 挪到左边并转置 → $dB = A^T \cdot dY$
助记:梯度的公式就是把 $Y = AB$ 里的 $Y$ 换成 $dY$,另一个因子转置并保持在原来的位置。
六、用更紧凑的方法验证:全微分
上面是硬推,这里给一种优雅的验证方法,也是研究者常用的技巧。
核心工具:对标量 $L$,有
意思是:$L$ 的微小变化 = 梯度与 $X$ 微小变化的"内积"(迹形式的 Frobenius 内积)。如果能把 $dL$ 写成 $\text{tr}(M^T dX)$,那么 $\frac{\partial L}{\partial X} = M$。
6.1 对 $A$ 求导
$Y = AB$,固定 $B$,$Y$ 的微分为 $dY = dA \cdot B$(这里 $dA$ 是 $A$ 的微小变化,不是梯度,记号冲突但上下文清楚)。
为避免混淆,换下记号:令 $G_Y = \frac{\partial L}{\partial Y}$。
利用迹的循环性 $\text{tr}(XYZ) = \text{tr}(ZXY)$:
对比 $dL = \text{tr}\left(\left(\frac{\partial L}{\partial A}\right)^T dA\right)$,读出:
6.2 对 $B$ 求导
固定 $A$,$dY = A \cdot dB$:
对比 $dL = \text{tr}\left(\left(\frac{\partial L}{\partial B}\right)^T dB\right)$:
这种方法不需要下标,几行就搞定,是研究矩阵微分的标准套路。推荐掌握。
七、一个小例子手动验证
取最小的情况:$A \in \mathbb{R}^{1 \times 2}$,$B \in \mathbb{R}^{2 \times 1}$,则 $Y = AB \in \mathbb{R}^{1 \times 1}$ 是标量。
记 $A = [a_1, a_2]$,$B = [b_1, b_2]^T$,则:
直接求偏导:
- $\frac{\partial Y}{\partial a_1} = b_1$,$\frac{\partial Y}{\partial a_2} = b_2$,所以 $dA_{\text{直接}} = [b_1, b_2] = B^T$
- $\frac{\partial Y}{\partial b_1} = a_1$,$\frac{\partial Y}{\partial b_2} = a_2$,所以 $dB_{\text{直接}} = [a_1, a_2]^T = A^T$
用公式验证(取 $dY = 1$,因为 $Y$ 就是 $L$):
- $dA = dY \cdot B^T = 1 \cdot B^T = B^T$ ✓
- $dB = A^T \cdot dY = A^T \cdot 1 = A^T$ ✓
手算结果和公式结果完全吻合。
八、扩展:常见变体
掌握了基础结论,其他情况都能直接推导:
| 前向 | 反向 |
|---|---|
| $Y = AB$ | $dA = dY \cdot B^T$,$dB = A^T \cdot dY$ |
| $Y = A^T B$ | $dA = B \cdot dY^T$,$dB = A \cdot dY$ |
| $Y = AB^T$ | $dA = dY \cdot B$,$dB = dY^T \cdot A$ |
| $Y = ABC$ | $dA = dY \cdot (BC)^T$,$dB = A^T \cdot dY \cdot C^T$,$dC = (AB)^T \cdot dY$ |
规律就是前面说的:每个因子的梯度 = 把公式里对应位置换成 $dY$,其他因子保持位置不变,但要转置。
九、在 Attention 里的具体应用
回到上一节,这三处都用了结论 1:
$O = PV \Rightarrow dV = P^T \cdot dO, dP = dO \cdot V^T$
- $P$ 在左,所以 $dV$ 公式里 $P$ 挪到左边并转置
- $V$ 在右,所以 $dP$ 公式里 $V$ 挪到右边并转置
$S = QK^T$,看成 $S = Q \cdot (K^T)$,所以:
- $dQ = dS \cdot (K^T)^T = dS \cdot K$
- $d(K^T) = Q^T \cdot dS$,再转置:$dK = dS^T \cdot Q$
这正是为什么那些公式"看起来对称但不完全对称"——$Q, K$ 在前向里不是对称进入的($K$ 带了转置),反向里自然也不对称。
十、总结
三种理解角度:
- 标量链式法则:逐元素写出 $Y_{ij} = \sum_p A_{ip} B_{pj}$,用多元链式法则加起来,识别出矩阵乘法结构
- 全微分 + 迹:$dL = \text{tr}((\partial L/\partial X)^T dX)$,用迹的循环性凑出梯度
- 形状对齐 / 位置对应:$dA$ 形状必须与 $A$ 一致,倒推转置位置;或者记"梯度公式就是前向公式中某个因子换成 $dY$"
三种方法殊途同归。日常使用推荐方法 3(最快),遇到陌生情况用方法 2(最严谨),方法 1 是理解背后机制的基础。
结论二: 设 $p = \text{softmax}(s)$,则 $dS = P \odot \big(dP - \text{rowsum}(dP \odot P)\big)$
结论 2:设 $p = \text{softmax}(s)$,其中 $s, p \in \mathbb{R}^N$,则:
扩展到按行 softmax 的矩阵形式:
下面从零开始推导。
一、softmax 的定义回顾
对向量 $s = (s_1, s_2, \dots, s_N) \in \mathbb{R}^N$,softmax 输出:
记分母 $Z = \sum_k e^{s_k}$,则 $p_i = e^{s_i} / Z$。
两个关键性质:
- $\sum_i p_i = 1$
- $p_i > 0$ 对所有 $i$
二、求雅可比矩阵 $\frac{\partial p_i}{\partial s_j}$
这是整个推导的核心。我们要算:$p_i$ 对每个 $s_j$ 的偏导。
由于 $p_i = e^{s_i} / Z$,而 $Z$ 依赖所有的 $s_j$(不只是 $s_i$),这是关键——改变任何一个 $s_j$ 都会通过 $Z$ 影响每个 $p_i$。
2.1 分情况推导
用商的求导法则:
先算两个关键偏导:
$\frac{\partial Z}{\partial s_j} = \frac{\partial}{\partial s_j}\sum_k e^{s_k} = e^{s_j}$(只有 $k = j$ 那一项有贡献)
其中 $\delta_{ij}$ 是 Kronecker delta:$i = j$ 时为 1,否则为 0。
2.2 代入整理
分子分母同除以 $Z^2$:
注意 $e^{s_i} / Z = p_i$,$e^{s_j} / Z = p_j$,所以:
2.3 分两种情况的形式
写得更直观一点:
对角上是 $p_i(1-p_i)$,非对角是 $-p_i p_j$。注意这是一个对称雅可比(因为 $p_i p_j = p_j p_i$)。
2.4 验证一下
简单验证:对角上 $p_i(1 - p_i) > 0$(softmax 的输出是增函数,增大 $s_i$ 会增大 $p_i$),非对角 $-p_i p_j < 0$(增大 $s_j$ 会减小别的 $p_i$),符合直觉。
还可以验证每列求和为 0:
这反映了 $\sum_i p_i = 1$ 这个约束——所有 $p_i$ 对任意 $s_j$ 的变化量加起来必须为 0。
三、雅可比的矩阵写法
把 $\frac{\partial p_i}{\partial s_j} = p_i \delta_{ij} - p_i p_j$ 整理成矩阵:
其中:
- $\text{diag}(p)$ 是对角矩阵,对角线上是 $p_1, p_2, \dots, p_N$(对应 $\delta_{ij} p_i$ 项)
- $p p^T$ 是外积矩阵,第 $(i,j)$ 元素是 $p_i p_j$
形状:$J \in \mathbb{R}^{N \times N}$,$J_{ij} = \frac{\partial p_i}{\partial s_j}$。
四、反向传播:求 $ds$
4.1 链式法则
给定上游梯度 $dp \in \mathbb{R}^N$($dp_i = \frac{\partial L}{\partial p_i}$),反向得到:
写成矩阵形式:
4.2 代入 $J$ 的表达式
逐项分析:
- $\text{diag}(p) \cdot dp$:对角矩阵左乘向量,等价于逐元素相乘,即 $p \odot dp$
- $p p^T dp$:先算 $p^T dp$,这是个标量(两个向量的点积),记为 $\rho = \sum_k p_k dp_k$。然后 $p \cdot \rho$ 是每个 $p_i$ 乘以这个标量。
所以:
4.3 因式分解形式
把 $p$ 提出来(每项都含 $p_i$):
写成向量:
这就是结论 2 的向量形式。$\rho$ 是个标量,被广播到 $dp$ 的每个位置做减法。
五、直觉理解:$ds$ 公式在说什么
这个公式有很清晰的几何/概率直觉。
5.1 $\rho = \sum_k p_k dp_k$ 是什么?
这是 $dp$ 在 $p$ 分布下的加权平均(因为 $\sum_k p_k = 1$,$p$ 本身是概率分布)。
5.2 $(dp - \rho)$ 是什么?
是 $dp$ 减去其加权平均,即"中心化"后的 $dp$。
5.3 为什么要中心化?
因为 softmax 有个根本约束:$\sum_i p_i = 1$,所以所有 $p_i$ 同时增加 1 单位 是不可能的——它们必须"此消彼长"。
从梯度角度看:如果 $dp$ 是一个常量向量(所有 $dp_i$ 相等),说明损失对 $p$ 的每个分量"要求"同样的变化,但 softmax 无法同时满足(受约束)。此时应该 $ds = 0$。
验证:若 $dp_i = c$(常量),则 $\rho = \sum_k p_k \cdot c = c$,所以 $dp - \rho = 0$,进而 $ds = 0$。完美符合直觉。
5.4 为什么乘 $p$?
中心化后,还要乘以 $p$(逐元素)。这反映了:$p_i$ 小的分量对 $s_i$ 的变化不敏感(因为 $p_i \approx 0$,改变 $s_i$ 影响很小)。
六、扩展到矩阵(按行 softmax)
在 attention 里,$P = \text{softmax}(S)$ 是按 $S$ 的每一行独立做 softmax。设 $S, P \in \mathbb{R}^{N \times N}$,则不同行之间互不影响——第 $i$ 行的 $P$ 只依赖第 $i$ 行的 $S$。
所以把前面的结论逐行应用即可。对第 $i$ 行:
写成矩阵形式:
其中 $\text{rowsum}(\cdot)$ 返回一个 $N \times 1$ 的列向量,每一行是对应行的和。减法时这个列向量被广播到整行(每行都减同一个标量)。
七、一个具体例子手算验证
取 $N = 2$,$s = (1, 2)$。
前向:
带入数值($e \approx 2.718$):$p_1 \approx 0.269$,$p_2 \approx 0.731$。
雅可比:
(注意到 $p_1 p_2 = p_1(1-p_1) = p_2(1-p_2)$,因为 $N=2$ 时 $p_2 = 1 - p_1$。)
假设上游梯度 $dp = (1, 0)$(即 $L = p_1$)。
用公式算:
- $\rho = p^T dp = p_1 \cdot 1 + p_2 \cdot 0 = p_1 \approx 0.269$
- $dp - \rho = (1 - 0.269, 0 - 0.269) = (0.731, -0.269)$
- $ds = p \odot (dp - \rho) = (0.269 \times 0.731, 0.731 \times (-0.269)) \approx (0.197, -0.197)$
直接用雅可比:
- $ds = J^T dp = J dp = (0.197, -0.197)$ ✓
再直接对 $p_1$ 求偏导验证:
- $\frac{\partial p_1}{\partial s_1} = p_1(1-p_1) \approx 0.197$ ✓
- $\frac{\partial p_1}{\partial s_2} = -p_1 p_2 \approx -0.197$ ✓
三种方法结果一致,公式正确。
八、为什么这个结果特别好?
在 attention 的反向里,这个公式有几个重要的工程价值:
8.1 $ds$ 可以逐行、逐元素算
公式 $dS_{ij} = P_{ij}(dP_{ij} - \rho_i)$ 是一个逐元素操作(加上每行一个共享的 $\rho_i$)。
这意味着如果把 $P, dP$ 分块,每个 $(i, j)$ 块可以独立算 $dS_{ij}$——只需要知道对应行的 $\rho_i$ 这一个标量。这正是 FlashAttention 能分块反向的基础。
8.2 $\rho_i$ 可以预计算或替换
$\rho_i = \sum_k P_{ik} dP_{ik}$ 看似需要完整的 $P$ 和 $dP$,但上一节的恒等式告诉我们:
所以 $D_i$ 可以在反向循环开始前一次性算好(只用 $dO$ 和 $O$),每个 block 需要时直接读取。
8.3 避免显式构造雅可比
若老老实实构造 $J = \text{diag}(p) - pp^T$,需要 $O(N^2)$ 存储一个稠密矩阵,然后做 $Jdp$ 这个矩阵-向量乘法 $O(N^2)$ 次运算。
用因式分解形式 $ds = p \odot (dp - \rho)$:
- 算 $\rho$:$O(N)$ 次乘加
- 减去 $\rho$:$O(N)$ 次减法
- 逐元素乘 $p$:$O(N)$ 次乘法
总共 $O(N)$,不需要显式构造雅可比矩阵。这在 $N$ 大(长序列)时是关键的优化。
九、另一种推导路径:用交叉熵直觉
如果你学过分类问题里的 softmax + 交叉熵,可能见过一个著名公式:
这个公式看起来和我们推的完全不同,但其实是一致的——它是特殊情况。
当损失是交叉熵 $L = -\sum_i y_i \log p_i$ 时:
代入通用公式 $ds = p \odot (dp - \rho)$:
一致!这也再次验证了通用公式的正确性。交叉熵之所以那么简洁,是因为它的 $dp$ 形式特别,让 $\rho = -1$ 消得很干净。
十、总结
核心结果:
三层理解:
- 数学层面:softmax 的雅可比是 $\text{diag}(p) - pp^T$,利用这个矩阵的特殊结构(一个对角加一个秩一),$Jdp$ 有 $O(N)$ 的分解形式
- 几何层面:$\rho$ 是 $dp$ 在 $p$ 分布下的加权平均,$dp - \rho$ 是中心化,反映了 softmax 的"概率和为 1"约束——共同平移不改变 softmax
- 工程层面:分解形式省时省空间,可以逐行、逐块计算,让 FlashAttention 的分块反向成为可能;$\rho$ 可以预计算或用 $dO \odot O$ 替换,彻底绕过 $N \times N$ 的 $P, dP$
在 FlashAttention 中的角色:这个结论加上结论 1(矩阵乘的梯度)和 $\rho = \text{rowsum}(dO \odot O)$ 的恒等式,构成了 attention 反向传播分块计算的完整数学基础。
反向传播的公式推导
这个推导是理解 FlashAttention 反向的基础。我会从最基本的链式法则开始,一步步把所有梯度推出来。
一、前向回顾与记号约定
前向传播:
其中 $Q, K, V \in \mathbb{R}^{N \times d}$(为简洁省略 $\sqrt{d}$ 缩放,加回去只是每处乘个常数)。
记号约定(这是关键,很多推导让人糊涂就是记号乱):
即 $dX$ 表示损失 $L$ 对 $X$ 的梯度,形状与 $X$ 完全相同。
已知:$dO \in \mathbb{R}^{N \times d}$(从上游传下来)。
目标:求 $dQ, dK, dV$,形状分别与 $Q, K, V$ 相同。
二、两个基础结论
推导矩阵梯度时,以下两个结论会被反复使用。先把它们证明清楚,后面直接套用。
结论 1:若 $Y = AB$,则
推导:标量形式 $Y_{ij} = \sum_k A_{ik} B_{kj}$。
$\frac{\partial Y_{i'j}}{\partial A_{ik}} = B_{kj} \cdot \mathbb{1}[i' = i]$,所以
同理得 $dB = A^T \cdot dY$。一个助记方法:梯度的形状必须和原矩阵一致,所以 $B$ 出现在右边就要转置,$A$ 出现在左边也要转置。
结论 2:行 softmax 的雅可比
设 $p = \text{softmax}(s)$,其中 $s, p \in \mathbb{R}^N$(单独一行),则给定 $dp$,
推导:softmax 的雅可比
所以
写成向量形式:$ds = p \odot dp - p \cdot (\sum_i p_i dp_i)$。
令 $\rho = \sum_i p_i dp_i$(这是个标量,即 $p$ 和 $dp$ 的点积),则:
扩展到整个矩阵 $P$(每行独立做 softmax),逐行应用上式,$\rho$ 变成每行一个的列向量,广播到对应行:
这里 $\text{rowsum}(\cdot)$ 返回一个 $N \times 1$ 的列向量,广播时每一行减去自己那行的标量。
三、逐步推导三个梯度
步骤 1:从 $O = PV$ 出发,求 $dV$ 和 $dP$
套用结论 1(令 $Y = O, A = P, B = V$):
形状检查:
- $dV$:$(N \times N)^T \cdot (N \times d) = (N \times d)$ ✓
- $dP$:$(N \times d) \cdot (d \times N) = (N \times N)$ ✓
步骤 2:从 $P = \text{softmax}(S)$,求 $dS$
直接套用结论 2:
形状检查:$N \times N$ ✓
步骤 3:从 $S = QK^T$,求 $dQ$ 和 $dK$
把 $S = QK^T$ 写成 $S = Q \cdot (K^T)$,套用结论 1(令 $Y = S, A = Q, B = K^T$):
对 $B = K^T$:
所以
形状检查:
- $dQ$:$(N \times N) \cdot (N \times d) = (N \times d)$ ✓
- $dK$:$(N \times N) \cdot (N \times d) = (N \times d)$ ✓
四、整合全流程
把五个公式串起来:
标准实现里,每一步都要在 HBM 中保存 $N \times N$ 的 $P, dP, dS$,显存 $O(N^2)$,且反复读写导致速度慢。
五、一个关键恒等式:$D$ 可以只用 $dO$ 和 $O$ 算
FlashAttention 的一个优化点在于:那个 $D = \text{rowsum}(dP \odot P)$ 看起来需要完整的 $dP$ 和 $P$,但其实可以完全绕过它们。
断言:
推导:注意 $O = PV$,$dP = dO \cdot V^T$。逐元素算:
对 $j$ 求和:
里面的 $\sum_j P_{ij} V_{jk} = (PV)_{ik} = O_{ik}$,于是:
即:
这个结论很重要,因为它意味着:
- $D \in \mathbb{R}^{N \times 1}$ 可以在反向循环开始之前用 $dO$ 和 $O$ 一次性算好
- 每个 block 在需要 $D_i$ 时直接读取,不需要 $P$ 或 $dP$ 的全貌
- 让分块反向在数学上变得干净:每个 block 所需的外部输入只有 $Q_i, K_j, V_j, dO_i, O_i, L_i, D_i$,全都是 $O(N)$ 或 $O(Nd)$ 的量
六、FlashAttention 分块反向的完整流程
有了上面的公式,可以把整个反向重写为分块形式:
1 | 预处理: D = rowsum(dO ⊙ O) # 形状 N×1, 一次扫描搞定 |
注意 $P_{ij}, dP_{ij}, dS_{ij}$ 都是 $B_r \times B_c$ 的小块,只活在 SRAM 里,计算完就丢。整个反向期间 HBM 中从未出现完整的 $N \times N$ 矩阵。
七、一句话总结
| 推导环节 | 用到的工具 |
|---|---|
| $dV, dP$ | 矩阵乘反向(结论 1) |
| $dS$ | softmax 雅可比(结论 2) |
| $dQ, dK$ | 矩阵乘反向(结论 1) |
| $D$ 的简化 | 交换求和顺序 + $O = PV$ 恒等式 |
最核心的两个技术点:
- 结论 2 的 softmax 反向 把 $dS$ 和 $P \odot dP$ 联系起来
- $D = \text{rowsum}(dO \odot O)$ 的恒等式 让整个反向可以完全分块,不需要 $P, dP$ 的全貌
这两点加上前向存下来的 $L$,让 FlashAttention 反向做到了数学上与标准实现完全等价,但显存 $O(N)$、速度更快。