1.FlashAttention 整体结构
FlashAttention 是由斯坦福大学 Tri Dao 等人在 2022 年提出的一种精确(非近似)的注意力计算算法。它通过重新设计注意力的计算方式,大幅降低了显存占用并加速了训练与推理,现已成为现代大模型(如 GPT、LLaMA 等)的标配。
一、问题背景:标准 Attention 的瓶颈
在 Transformer 中,自注意力的核心计算是:
其中 $Q, K, V \in \mathbb{R}^{N \times d}$,$N$ 是序列长度,$d$ 是每个头的维度。
标准实现的做法(PyTorch 默认的 naive 实现):
- 计算 $S = QK^T$,得到一个 $N \times N$ 的矩阵,写回 HBM(显存)
- 计算 $P = \text{softmax}(S)$,再次写回 HBM
- 计算 $O = PV$,写回 HBM
两个核心问题:
- 显存占用 $O(N^2)$:当 $N=8192$ 时,单个注意力矩阵就要 256MB(fp32),多头多层叠加显存爆炸
- 速度瓶颈不是算力,而是显存带宽:GPU 的算力增长远快于显存带宽,注意力计算是典型的 memory-bound(访存受限)问题,大量时间花在 HBM 读写上
二、GPU 内存层级的关键认知
理解 FlashAttention 必须先理解 GPU 的内存结构:
| 层级 | 容量 | 带宽 | 速度 |
|---|---|---|---|
| HBM(全局显存) | 40–80 GB | ~1.5–3 TB/s | 慢 |
| SRAM(片上共享内存) | ~20 MB(A100 全部 SM 合计) | ~19 TB/s | 快约 10 倍 |
标准 attention 反复在 HBM 中读写大矩阵,SRAM 几乎没被利用。FlashAttention 的核心思路就是:尽可能把计算留在 SRAM 里,避免往 HBM 写中间结果。
三、核心思想:Tiling + Online Softmax + Recomputation
1. Tiling(分块计算)
把 $Q, K, V$ 沿序列维度切成小块(block),每次只把一小块加载到 SRAM 中计算,这样 $N \times N$ 的大矩阵永远不会被完整地实例化。
假设将 $Q$ 切成 $T_r$ 块,$K, V$ 切成 $T_c$ 块,算法变成一个双层循环: 外层遍历 $Q$ 的块,内层遍历 $K, V$ 的块。
2. Online Softmax(增量式 softmax)
最大的技术难点是: softmax 需要看到整行所有元素才能归一化(要算分母 $\sum e^{x_i}$),而分块计算时一次只能看到一部分。FlashAttention 采用了 Milakov & Gimelshein 提出的 online softmax 技巧。
为了数值稳定,softmax 通常减去最大值:
当新的一块数据进来,当前块最大值是 $m^{new}$,则可以这样更新:
输出 $O$ 也可以用类似的 rescaling 进行增量更新:
这样每处理完一个 $K, V$ 块,就用缩放因子修正之前累积的结果,数学上与标准 softmax 完全等价,精度没有任何损失。
3. Recomputation(反向传播时重算)
在反向传播中,原本需要保存 $N \times N$ 的注意力矩阵 $P$ 来算梯度。FlashAttention 只保存 softmax 的统计量(每行的最大值 $m$ 和归一化因子 $\ell$,各 $O(N)$),反向时重新计算 $S$ 和 $P$。
看起来增加了计算量,但由于 attention 是 memory-bound 的,省下来的 HBM 访问时间远多于重算的时间,整体反而更快。
四、复杂度分析
设 $M$ 是 SRAM 大小:
| 指标 | 标准 Attention | FlashAttention |
|---|---|---|
| 计算量(FLOPs) | $O(N^2 d)$ | $O(N^2 d)$ |
| HBM 访问量 | $O(Nd + N^2)$ | $O(N^2 d^2 / M)$ |
| 显存占用 | $O(N^2)$ | $O(N)$ |
计算量没变,但 HBM 访问量在 $d \ll M$ 的常见情况下大幅降低($d$ 一般是 64 或 128,$M$ 在 A100 上约 100KB 级别),这是加速的根本来源。
3.Online Softmax(增量式 softmax)
Online Softmax 是 FlashAttention 能成立的数学基石。它由 NVIDIA 的 Milakov & Gimelshein 在 2018 年的论文 “Online normalizer calculation for softmax” 中提出,核心思想是:在只看过部分数据的情况下,增量地计算数值稳定的 softmax,并在新数据到来时做修正。
下面我会从标准 softmax 的问题开始,一步步推导到 online 版本,并解释它在 FlashAttention 中的角色。
一、标准 Softmax 与数值稳定性问题
1.1 朴素定义
给定向量 $x = (x_1, x_2, \dots, x_N)$,softmax 定义为:
1.2 数值溢出问题
如果某个 $x_i$ 较大(比如 1000),$e^{x_i}$ 会溢出变成 inf。即使在 fp16 下,$x_i > 11$ 就会溢出($e^{11} \approx 59874$,fp16 的上限约 65504)。
1.3 Safe Softmax(3-pass 版本)
解决方法是减去最大值:
这个变换数学上完全等价(分子分母同时乘以 $e^{-m}$),但把最大指数限制在 0,保证 $e^{x_i - m} \in (0, 1]$,不会溢出。
算法变成 3 遍扫描:
1 | Pass 1: 求最大值 m |
问题:需要 3 次读取整个 $x$ 向量,对 GPU 这种 memory-bound 场景非常不友好。如果 $x$ 太大放不进 SRAM,就得反复从 HBM 读 3 次。
二、Online Softmax 的核心想法
目标:把 Pass 1 和 Pass 2 合并成一次扫描,也就是边扫描边同时维护 $m$ 和 $\ell$。
难点在于:$\ell$ 的定义依赖 $m$,而 $m$ 要扫完整个数组才能确定。扫到一半时的”局部最大值”可能比真正的最大值小,之前算出来的 $\ell$ 就是错的。
关键洞察:如果有办法在 $m$ 发生变化时修正 $\ell$,就可以一边扫一边算。
2.1 推导修正公式
假设已经处理了前 $i$ 个元素,维护着:
- $m_i$ = 前 $i$ 个元素中的最大值
- $\ell_i = \sum_{j=1}^{i} e^{x_j - m_i}$ = 基于当前最大值的归一化因子
现在来了第 $i+1$ 个元素 $x_{i+1}$,新的最大值是:
要得到新的 $\ell_{i+1} = \sum_{j=1}^{i+1} e^{x_j - m_{i+1}}$,观察:
对旧部分做一个简单变换:
于是得到更新公式:
几何直觉: 每次最大值增大(即 $m_{i+1} > m_i$),之前累积的 $\ell_i$ 是基于”旧基准” $m_i$ 的,现在基准变大了,所有旧项都得乘以 $e^{m_i - m_{i+1}} < 1$ 进行缩小(rescale)。如果最大值没变($m_{i+1} = m_i$),缩放因子是 $e^0 = 1$,什么都不用调整。
2.2 2-pass 算法
1 | Pass 1: 同时计算 m 和 ℓ |
从 3 遍扫描降到了 2 遍,而且数值稳定性与 safe softmax 完全相同(始终保持减去当前最大值)。
三、分块(Block-wise)版本:FlashAttention 的真正需求
单元素的更新公式只是开胃菜。FlashAttention 需要按块处理——一次性算完一个块的局部 $m$ 和 $\ell$,再和已有的累积量合并。
3.1 两块合并公式
设已经处理完块 A,维护着 $(m_A, \ell_A)$; 现在处理块 B,独立算出 $(m_B, \ell_B)$,其中:
合并后的全局量 $(m, \ell)$ 应为:
这就是 FlashAttention 里那个看起来眼熟的公式。推导完全同上:两块都要把基准从自己的局部最大值改到全局最大值,各自乘以对应的 $e^{m_{\text{local}} - m_{\text{global}}}$ 缩放因子。
注意这个合并操作是满足结合律的——合并 A、B 再合并 C,和先合并 B、C 再和 A 合并,结果一致。这使得它能很自然地并行化(比如 tree reduction)。
四、扩展:同时在线更新输出 $O$
在普通 softmax 里,Pass 2 要重新扫一遍算 $y_i$。但在 attention 里,我们最终要的不是 softmax 本身,而是 $O = PV$,其中 $P = \text{softmax}(S)$。FlashAttention 需要把 $O$ 也一起在线累积,这样整个算法变成 1-pass。
4.1 定义未归一化的输出
定义”未归一化”的输出(只做了减最大值但没除以 $\ell$):
处理块 B 时同理:
合并后全局未归一化输出应为 $\tilde{O} = \sum_{j \in A \cup B} e^{x_j - m} V_j$,同样需要基准变换:
最后一步再除以 $\ell$ 得到真正的 $O = \tilde{O} / \ell$。
4.2 伪代码(FlashAttention 核心)
1 | m = -∞, ℓ = 0, O = 0 # O 此处表示未归一化的 Õ |
这样,整个 attention 只需一次扫描 $K, V$,完全不需要实例化 $N \times N$ 的注意力矩阵。FlashAttention-2 的一个小优化就是把每一步的 / ℓ 推迟到最后,避免反复做除法。
五、为什么数学上精确?
这点很多人初学时会怀疑:”边扫边更新,最后结果真的和一次性算一样吗?”
答案是完全一样,因为每次合并时:
- $m$ 始终是当前已见元素的真实最大值
- $\ell$ 通过缩放因子修正后,始终等于 $\sum_{j \in \text{已见}} e^{x_j - m_{\text{当前}}}$
- $\tilde{O}$ 同理
扫描结束时,$m$ 等于全局最大值,$\ell$ 等于全局归一化因子,$\tilde{O} / \ell$ 等于标准 attention 的输出。这和”提前知道 $m$ 再一口气算”在代数上完全等价——没有任何近似。
六、一个数值例子
以 $x = [1, 3, 2, 5]$ 为例,手工走一遍:
标准做法: $m = 5$,$\ell = e^{-4} + e^{-2} + e^{-3} + e^0 \approx 0.0183 + 0.1353 + 0.0498 + 1 = 1.2034$
Online 做法:
| 步骤 | 新元素 | $m_{\text{new}}$ | 缩放因子 $e^{m_{\text{old}} - m_{\text{new}}}$ | $\ell$ 更新 |
|---|---|---|---|---|
| 初始 | — | $-\infty$ | — | $\ell = 0$ |
| 步 1 | 1 | 1 | $e^{-\infty - 1} = 0$ | $\ell = 0 \cdot 0 + e^0 = 1$ |
| 步 2 | 3 | 3 | $e^{1 - 3} = 0.1353$ | $\ell = 1 \cdot 0.1353 + e^0 = 1.1353$ |
| 步 3 | 2 | 3 | $e^{3 - 3} = 1$ | $\ell = 1.1353 \cdot 1 + e^{-1} = 1.1353 + 0.3679 = 1.5032$ |
| 步 4 | 5 | 5 | $e^{3 - 5} = 0.1353$ | $\ell = 1.5032 \cdot 0.1353 + e^0 = 0.2034 + 1 = 1.2034$ ✓ |
最后得到 $\ell = 1.2034$,与标准做法完全一致。
七、工程意义
Online softmax 看似只是一个小技巧,实际上解锁了一整类算法:
- FlashAttention — 让 attention 从 $O(N^2)$ 显存降到 $O(N)$
- Ring Attention — 把序列分片到多 GPU,每个 GPU 持一部分 $K, V$,通过环形通信传递部分结果,用 online softmax 合并。支持百万级 token 序列
- 流式推理 / Prefix Caching — 先算一部分 prefix 的 softmax 状态,后续 token 来时增量更新
- 分布式 softmax — 跨设备的 softmax 归一化
它的思想可以抽象为:任何 “reduce + 归一化”的操作,只要能找到带修正因子的合并律,就能分块并行计算。这一点也和 log-sum-exp 的并行化密切相关(实际上 $\log \ell + m$ 就是 log-sum-exp,online softmax 等价于 log-sum-exp 的在线计算)。