设定一个具体例子 为了让矩阵形状清晰,我们用一个小模型:
1 2 3 4 d (hidden dim) = 8 h (num heads) = 2 d_h (head dim) = 4 (= d / h) B (batch size) = 1 (为简化,省略 batch 维)
Prompt:"The cat sat",假设 tokenize 后是 3 个 token,记为 x₁, x₂, x₃。
目标 :先 prefill 处理 prompt,然后 decode 逐个生成 x₄, x₅, x₆。
阶段 1:Prefill(一次性处理 3 个 token) Step 1.1:Embedding 输入 token ids [t₁, t₂, t₃] 经过 embedding 层:
1 2 3 X = [ x₁ ] 形状 [3, 8],即 [seq_len=3, d=8] [ x₂ ] [ x₃ ]
每一行是一个 token 的 8 维向量。
Step 1.2:QKV 投影 用三个权重矩阵 W_Q, W_K, W_V,每个形状 [d=8, d=8]:
1 2 3 Q = X @ W_Q → [3, 8] @ [8, 8] = [3, 8] K = X @ W_K → [3, 8] @ [8, 8] = [3, 8] V = X @ W_V → [3, 8] @ [8, 8] = [3, 8]
具体地(每行代表一个 token 的 Q/K/V 向量):
1 2 3 Q = [ q₁ ] K = [ k₁ ] V = [ v₁ ] [ q₂ ] [ k₂ ] [ v₂ ] [ q₃ ] [ k₃ ] [ v₃ ]
Step 1.3:Reshape 成多头 把最后一维 d=8 拆成 h=2 个 head,每个 head d_h=4,然后转置让 head 维度在前:
1 2 3 Q: [3, 8] → [3, 2, 4] → [2, 3, 4] ([h, seq, d_h]) K: [3, 8] → [3, 2, 4] → [2, 3, 4] V: [3, 8] → [3, 2, 4] → [2, 3, 4]
现在每个 head 独立处理。为了讲清楚,只看 head 0 (head 1 完全对称):
1 2 3 4 5 6 7 8 9 10 11 Q^(0) = [ q₁ ] 形状 [3, 4] [ q₂ ] (每行是一个 token 在 head 0 上的 4 维 query) [ q₃ ] K^(0) = [ k₁ ] 形状 [3, 4] [ k₂ ] [ k₃ ] V^(0) = [ v₁ ] 形状 [3, 4] [ v₂ ] [ v₃ ]
Step 1.4:Attention Score = Q @ K^T 这是关键的一步。Q @ K^T 计算每个 query 和每个 key 的相似度:
1 Q^(0) @ (K^(0))^T : [3, 4] @ [4, 3] = [3, 3]
形象地画出来:
1 2 3 4 5 6 7 8 9 10 11 K^T (transposed) ┌────────────────┐ │ k₁ᵀ k₂ᵀ k₃ᵀ │ 每列是一个 key └────────────────┘ Q ↓ 矩阵乘 ┌──────┐ ┌──────────────────────┐ │ q₁ │ → │ q₁·k₁ q₁·k₂ q₁·k₃ │ ← q₁ 与所有 key 的点积 │ q₂ │ → │ q₂·k₁ q₂·k₂ q₂·k₃ │ ← q₂ 与所有 key 的点积 │ q₃ │ → │ q₃·k₁ q₃·k₂ q₃·k₃ │ ← q₃ 与所有 key 的点积 └──────┘ └──────────────────────┘ S 形状 [3, 3]
S[i][j] = qᵢ · kⱼ,表示 token i 关注 token j 的”原始注意力分数”。
Step 1.5:Causal Mask + Softmax 由于是 decoder,token i 不能看 token j > i ,所以右上三角要被 mask 成 -∞:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 S 加 mask 后: ┌─────────────────────┐ │ q₁·k₁ -∞ -∞ │ ← token 1 只能看 token 1 │ q₂·k₁ q₂·k₂ -∞ │ ← token 2 只能看 token 1, 2 │ q₃·k₁ q₃·k₂ q₃·k₃│ ← token 3 可以看 token 1, 2, 3 └─────────────────────┘ 除以 √d_h = 2,然后 softmax(按行): A = ┌──────────────────────┐ │ 1.00 0 0 │ token 1 完全 attend 自己 │ 0.4 0.6 0 │ token 2 在 token 1, 2 间分布 │ 0.2 0.3 0.5 │ token 3 在 token 1, 2, 3 间分布 └──────────────────────┘ 形状仍是 [3, 3]
每行和为 1。这就是 attention weights。
Step 1.6:加权求和 A @ V 用注意力权重对 V 加权求和:
1 A @ V^(0): [3, 3] @ [3, 4] = [3, 4]
形象画法:
1 2 3 4 5 6 7 8 9 10 11 12 13 V ┌──────┐ │ v₁ │ ← 4 维 │ v₂ │ │ v₃ │ └──────┘ A ↓ 加权求和 ┌──────────────┐ ┌──────────────────────┐ │ 1.0 0 0 │ → │ 1.0·v₁ + 0·v₂ + 0·v₃│ = output₁ │ 0.4 0.6 0 │ → │ 0.4·v₁ + 0.6·v₂ │ = output₂ │ 0.2 0.3 0.5 │ → │ 0.2·v₁ + 0.3·v₂+0.5·v₃│=output₃ └──────────────┘ └──────────────────────┘ O^(0) 形状 [3, 4]
每一行 outputᵢ 是 token i 的 context-aware 表示 ——它是历史所有 token 的 value 向量的加权混合。
Step 1.7:合并多头 + 输出投影 1 2 3 4 O^(0): [3, 4] ┐ ├→ concat → [3, 8] → @ W_O [8, 8] → [3, 8] O^(1): [3, 4] ┘ ↑ 最终 attention 输出
然后经过 residual、LayerNorm、FFN、再 residual……到下一层。
Step 1.8:Prefill 的关键产物——KV Cache 这一层结束时,把 K, V 存进 cache :
1 2 3 4 5 6 7 8 KV_Cache[layer_i]: K_cache = [ k₁ ] 形状 [h=2, seq=3, d_h=4] [ k₂ ] [ k₃ ] V_cache = [ v₁ ] 形状 [h=2, seq=3, d_h=4] [ v₂ ] [ v₃ ]
每一层都有自己独立的 KV cache。
Step 1.9:取最后一个 token 预测 x₄ 经过所有层后得到 [3, 8],取最后一行 [1, 8],过 LM head [8, vocab],softmax 采样得到 x₄。
Prefill 完成。计算特点:3 个 token 一次性算,是一个标准的矩阵乘 GEMM,计算密集。
阶段 2:Decode(逐 token 生成 x₄) 现在我们要算 x₄ 经过 attention 的输出。关键问题:x₄ 的 query 需要 attend 到 x₁, x₂, x₃, x₄ 的 key/value。 但 x₁, x₂, x₃ 的 K/V 已经在 cache 里了!
Step 2.1:只输入 1 个新 token 1 X_new = [ x₄ ] 形状 [1, 8]
注意:只有 1 行 ,不是 4 行。这是和 prefill 最大的区别。
Step 2.2:QKV 投影(只算新 token) 1 2 3 q_new = X_new @ W_Q → [1, 8] @ [8, 8] = [1, 8] k_new = X_new @ W_K → [1, 8] @ [8, 8] = [1, 8] v_new = X_new @ W_V → [1, 8] @ [8, 8] = [1, 8]
注意 FLOPs 上的省力 :投影从 [3, 8] @ [8, 8] 变成 [1, 8] @ [8, 8],省了 3 倍。 更重要的是:这不再是矩阵-矩阵乘(GEMM),而是矩阵-向量乘(GEMV) ——这是 decode 变成 memory-bound 的根因。
Step 2.3:Reshape 多头 1 2 3 q_new: [1, 8] → [1, 2, 4] → [2, 1, 4] ([h, 1, d_h]) k_new: [1, 8] → [1, 2, 4] → [2, 1, 4] v_new: [1, 8] → [1, 2, 4] → [2, 1, 4]
只看 head 0:
1 2 3 q_new^(0) = [ q₄ ] 形状 [1, 4] k_new^(0) = [ k₄ ] 形状 [1, 4] v_new^(0) = [ v₄ ] 形状 [1, 4]
Step 2.4:Cache 更新——把 k₄, v₄ 拼接进去 1 2 3 4 5 6 7 8 9 10 11 12 更新前: K_cache^(0) = [ k₁ ] 形状 [3, 4] [ k₂ ] [ k₃ ] 更新后(concat 在 seq 维): K^(0) = [ k₁ ] 形状 [4, 4] ← 多了一行 k₄ [ k₂ ] [ k₃ ] [ k₄ ] ← 刚算出的 V^(0) 同理:[4, 4]
实际实现中,cache 是预分配的 [h, max_len, d_h],写入是 K_cache[:, current_pos, :] = k_new,没有真的 concat。
Step 2.5:Attention Score——形状的关键变化 1 q_new^(0) @ (K^(0))^T : [1, 4] @ [4, 4] = [1, 4]
注意!这里左操作数只有 1 行 :
1 2 3 4 5 6 7 8 9 K^T ┌────────────────────┐ │ k₁ᵀ k₂ᵀ k₃ᵀ k₄ᵀ│ 4 列 └────────────────────┘ q_new ↓ ┌──────┐ ┌──────────────────────────┐ │ q₄ │ → │ q₄·k₁ q₄·k₂ q₄·k₃ q₄·k₄│ └──────┘ └──────────────────────────┘ s 形状 [1, 4]
只有 1 行 score ,因为我们只关心 q₄ 对所有历史 key 的 attention。
Step 2.6:不需要 Causal Mask! 为什么?因为 q₄ 本来就是”最新的 token”,它能看到所有历史是合法的,而 cache 里根本不存在 k₅, k₆ (未来的 token 还没生成)。Mask 的作用是阻止看未来——但未来此刻还不存在。
1 2 3 直接除以 √d_h,softmax: a = [ a₁ a₂ a₃ a₄ ] 形状 [1, 4],和为 1
Step 2.7:加权求和 a @ V 1 a^(0) @ V^(0) : [1, 4] @ [4, 4] = [1, 4]
形象画法:
1 2 3 4 5 6 7 8 9 10 11 12 V (包括 cache 和 v₄) ┌──────┐ │ v₁ │ │ v₂ │ │ v₃ │ │ v₄ │ └──────┘ a ↓ ┌─────────────────┐ ┌────────────────────────────────┐ │ a₁ a₂ a₃ a₄ │ → │ a₁·v₁ + a₂·v₂ + a₃·v₃ + a₄·v₄│ └─────────────────┘ └────────────────────────────────┘ output₄ 形状 [1, 4]
只得到 1 个输出向量 ——就是 x₄ 经过这一层 attention 后的表示。
Step 2.8:后续 + 下一步 output^(0), output^(1) concat 成 [1, 8] → W_O → 后续层 → LM head → 采样得到 x₅。
然后回到 Step 2.1,把 x₅ 作为新的 X_new,cache 长度变成 5,继续……
用一张图对比 Prefill vs Decode 的矩阵形状 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 ┌───────────────────────┬───────────────────────┐ │ PREFILL │ DECODE │ │ (3 tokens 一起) │ (1 token 一步) │ ├───────────────────────┼───────────────────────┤ │ │ │ 输入 X │ [3, 8] │ [1, 8] │ │ │ │ Q │ [2, 3, 4] │ [2, 1, 4] │ │ ↑全部新算 │ ↑只算新的 │ │ │ │ K │ [2, 3, 4] │ [2, 4, 4] │ │ ↑全部新算→存cache │ ↑3 来自 cache+1新│ │ │ │ V │ [2, 3, 4] │ [2, 4, 4] │ │ │ │ Q @ K^T │ [3,4]@[4,3] = [3,3] │ [1,4]@[4,4] = [1,4] │ │ 方阵,需要 causal mask│ 长条,不需要 mask │ │ │ │ A @ V │ [3,3]@[3,4] = [3,4] │ [1,4]@[4,4] = [1,4] │ │ │ │ 计算类型 │ GEMM(矩阵×矩阵) │ GEMV(向量×矩阵) │ 瓶颈 │ Compute-bound │ Memory-bound │ └───────────────────────┴───────────────────────┘
几个最容易混淆的点 1. 为什么 Q 不需要 cache? 因为生成时每步只用”最新位置的 query”——历史 token 的 query 算出来后只用于一次性产生那一步的 output,之后不再需要。而 K, V 会被未来所有 step 反复 attend 到,所以必须缓存。
2. 为什么 Decode 的 Q@K^T 不需要 mask? Prefill 时是 [3, 3] 的方阵,每行代表不同的 query,要阻止靠前的 query 看到靠后的 key,所以要 mask。Decode 时 [1, L] 只有一行 query,是”最新的”,本来就该看所有历史,而未来的 key 还不存在,所以天然不需要 mask。
3. Cache 长度何时增加? 每个 decode step 增加 1。生成 100 个 token,cache 从 L₀ 增长到 L₀+100。这就是为什么长生成会越来越慢——q@K^T 中 K 的长度在变长。
4. 一次只算一行 Q——是不是浪费 GPU? 是的!这正是 decode memory-bound 的根本原因。GPU 擅长 GEMM 但 GEMV 算术强度低。所以才有 continuous batching (把多个请求的 q_new 堆成 [batch, 1, d],让矩阵长起来)和 speculative decoding (一次猜多个 token 让 q 行数变大)这些优化。
把整个过程串起来的伪代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 def generate (prompt_ids, max_new=100 ): X = embed(prompt_ids) kv_cache = [] for layer in layers: Q = X @ layer.W_Q K = X @ layer.W_K V = X @ layer.W_V S = Q @ K.transpose() / sqrt(d_h) S = S + causal_mask A = softmax(S) out = A @ V kv_cache.append((K, V)) X = post_attention(out, layer) next_token = sample(lm_head(X[-1 ])) for step in range (max_new): X = embed([next_token]) for i, layer in enumerate (layers): q = X @ layer.W_Q k = X @ layer.W_K v = X @ layer.W_V K_full = concat(kv_cache[i].K, k) V_full = concat(kv_cache[i].V, v) kv_cache[i] = (K_full, V_full) s = q @ K_full.transpose() / sqrt(d_h) a = softmax(s) out = a @ V_full X = post_attention(out, layer) next_token = sample(lm_head(X[-1 ]))