2.3 Self-Attention机制深入理解
Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块。无论你是想理解 FlashAttention 背后的 IO 优化思想,还是想搞清楚 GQA、MLA 这些 Attention 变种为什么能减少推理开销,都绕不开对 Self-Attention 机制的深入理解。本文将从 Attention 的历史起源讲起,逐步拆解 Scaled Dot-Product Attention 的每一步数学原理,手写 PyTorch 实现,分析计算瓶颈,最后延伸到 FlashAttention 和各种 Attention 变种,力求让读者建立从直觉到公式再到工程实现的完整认知链条。
📑 目录
- 1. Attention 的历史演进
- 2. 从信息检索角度理解 QKV
- 3. Scaled Dot-Product Attention 详解
- 4. PyTorch 手动实现
- 5. 详细维度推导
- 6. 为什么要除以 $\sqrt{d_k}$:方差稳定性的数学推导
- 7. Softmax 的数值稳定性
- 8. Multi-Head Attention
- 9. MHA vs MQA vs GQA vs MLA
- 10. Causal Mask 的作用与实现
- 11. Self-Attention 的计算瓶颈分析
- 12. FlashAttention 的核心思想
- 自我检验清单
- 参考资料
1. Attention 的历史演进
在 2017 年 Transformer 横空出世之前,Attention 机制已经在序列建模领域酝酿了好几年。理解它的演进脉络,有助于我们把握 Self-Attention 到底解决了什么问题,以及为什么它的设计长成今天这个样子。
1.1 Bahdanau Attention (2014):开山之作
传统的 Encoder-Decoder 架构(比如用 RNN 做机器翻译)有一个致命缺陷:Encoder 把整个输入序列压缩成一个固定长度的向量,再交给 Decoder 去生成输出。当输入序列很长时,这个固定向量根本装不下所有信息,翻译质量会急剧下降。
Bahdanau 等人在 2014 年提出了一个关键改进:别硬压缩了,让 Decoder 在生成每个词的时候,自己回头”看”Encoder 的所有隐状态,按需取用。具体做法是,Decoder 当前时刻的隐状态作为”查询”,跟 Encoder 每个时刻的隐状态做”匹配”(通过一个小型前馈网络计算匹配分数),然后用匹配分数对 Encoder 的隐状态加权求和,得到一个”上下文向量”,作为当前解码步的额外输入。
这就是 Attention 的雏形:让模型学会”注意”输入序列的不同部分。不过 Bahdanau Attention 的匹配函数是一个额外的前馈网络(additive attention),计算效率不高。
1.2 Luong Attention (2015):简化计算
Luong 在 2015 年对 Bahdanau Attention 做了两个关键简化:一是提出了更高效的匹配函数(直接用点积代替前馈网络),二是探索了多种对齐方式(全局 vs 局部)。其中,点积 Attention 的计算方式——两个向量直接做内积来衡量相似度——后来成为了 Transformer 的基础。
点积的计算效率远高于前馈网络:它可以被表示为矩阵乘法,天然适合 GPU 并行加速。这个看似简单的改进,为后来 Self-Attention 的大规模应用铺平了道路。
1.3 Self-Attention (2017):Attention Is All You Need
2017 年的 Transformer 论文做了一个大胆的决定:完全抛弃 RNN,只用 Attention 来建模序列。
之前的 Attention 都是”跨序列”的——Decoder 去关注 Encoder。而 Self-Attention 是”自关注”——序列中的每个位置去关注同一个序列中的所有位置(包括自己)。这样做有两大优势:
- 并行性:RNN 必须逐步计算(第 $t$ 步依赖第 $t-1$ 步的结果),而 Self-Attention 中所有位置可以同时计算,天然适合 GPU 并行。
- 长距离依赖:RNN 中距离较远的两个词需要通过多步传递才能交互信息,信号会逐步衰减。Self-Attention 中任意两个位置之间只需要一步就能直接交互,理论上不存在距离衰减问题。
代价是什么?Self-Attention 的计算复杂度是 $O(N^2)$——序列中每对位置都需要计算匹配分数,而 RNN 的复杂度是 $O(N)$。这个平方复杂度成为后续无数优化工作的出发点。
2. 从信息检索角度理解 QKV
Attention 中最核心的三个概念是 Query(Q)、Key(K)和 Value(V)。这三个词直接借鉴了信息检索领域的术语,用一个日常场景来类比可以帮助建立直觉。
2.1 图书馆检索的类比
想象你走进一座图书馆,你脑子里有一个模糊的需求:”我想了解关于并行计算的内容”。这个需求就是你的 Query。
图书馆里每一本书的书脊上都贴着标签——“操作系统”、”计算机网络”、”并行计算导论”、”数据结构”等等。这些标签就是每本书的 Key。
你拿着自己的 Query,去和每本书的 Key 做比对。”并行计算导论”这个 Key 和你的 Query 高度匹配,匹配分数最高;”操作系统”可能有一些相关性,分数中等;”数据结构”和你的需求关系不大,分数很低。
确定了匹配分数之后,你不是把书脊标签(Key)拿走,而是把每本书的实际内容(Value)按照匹配分数加权汇总。”并行计算导论”的内容占大比重,”操作系统”的内容占一点,”数据结构”几乎不占——最终你得到的就是一份以并行计算为主、兼顾一点操作系统知识的综合信息。
2.2 三元组的形式化定义
回到 Self-Attention 的语境。给定输入序列 $X$(形状为 $(N, d_{model})$),三个线性变换将 $X$ 映射到不同的空间:
- **$Q = X \cdot W_Q$**:每个 token 生成自己的”查询向量”——“我需要什么信息”
- **$K = X \cdot W_K$**:每个 token 生成自己的”索引向量”——“我能提供什么信息的线索”
- **$V = X \cdot W_V$**:每个 token 生成自己的”内容向量”——“我实际携带的信息”
Q、K、V 之所以要从同一个 $X$ 做三次不同的线性变换,而不是直接用 $X$ 本身,原因在于解耦不同的角色。一个 token “需要什么”和它”能提供什么”往往是不同的。比如在一个句子里,动词可能需要关注它的主语和宾语(Query 的方向),但它作为被别人关注的对象时,提供的是动作语义(Key/Value 的方向)。三个独立的投影矩阵让模型有自由度去学习这些不同的映射。
3. Scaled Dot-Product Attention 详解
有了 Q、K、V 之后,Attention 的计算过程可以分解为清晰的五步。
3.1 第一步:线性投影
输入 $X$ 的形状为 $(N, d_{model})$。三个权重矩阵 $W_Q$、$W_K$、$W_V$ 的形状都是 $(d_{model}, d_{model})$(以单头为例):
$$
\begin{aligned}
Q &= X \cdot W_Q \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \
K &= X \cdot W_K \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \
V &= X \cdot W_V \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model})
\end{aligned}
$$
这三次矩阵乘法是三次独立的 GEMM(General Matrix Multiply)操作。在实际实现中,为了提高 GPU 利用率,通常会把 $W_Q$、$W_K$、$W_V$ 合并成一个大矩阵 $W_{QKV}$(形状 $(d_{model}, 3 \times d_{model})$),做一次 GEMM 然后 split,这样能更好地利用 GPU 的算力。
3.2 第二步:计算原始注意力分数
用 $Q$ 和 $K$ 的转置做矩阵乘法,衡量每对 token 之间的匹配程度:
$$
S = Q K^\top \quad (N, d_{model}) \times (d_{model}, N) = (N, N)
$$
结果矩阵 $S$ 的每个元素 $S[i][j]$ 是第 $i$ 个 token 的查询向量和第 $j$ 个 token 的索引向量的点积,代表 token $i$ 对 token $j$ 的原始关注程度。
这一步产生了一个 $N \times N$ 的矩阵,这就是 Self-Attention 平方复杂度的根源。
3.3 第三步:缩放
将原始分数除以 $\sqrt{d_k}$($d_k$ 是 Key 向量的维度):
$$
S_{\text{scaled}} = \frac{S}{\sqrt{d_k}}
$$
为什么需要缩放?简单来说,当维度 $d_k$ 较大时,点积的结果会变得很大,导致后续 Softmax 的梯度趋近于零。详细的数学推导见第 6 节。
3.4 第四步:Softmax 归一化
对缩放后的分数矩阵按行做 Softmax,将原始分数转换为概率分布:
$$
A = \text{softmax}(S_{\text{scaled}}) \quad \in \mathbb{R}^{N \times N}, \text{ 每行和为 1}
$$
Softmax 的作用是双重的:一方面把任意实数映射到 $(0, 1)$ 区间,使其可以作为权重;另一方面保证每行的权重之和为 1,形成一个合法的概率分布。
3.5 第五步:加权求和
用注意力权重对 Value 矩阵做加权求和:
$$
\text{Output} = A \cdot V \quad (N, N) \times (N, d_{model}) = (N, d_{model})
$$
最终每个 token 位置得到一个 $d_{model}$ 维的向量,里面融合了它所”关注”的所有 token 的信息,关注程度由 $A$ 的权重决定。
3.6 完整公式
将以上步骤合并,Self-Attention 的计算可以用一行公式概括:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$
最后还要再过一个输出投影矩阵 $W_O$(形状 $(d_{model}, d_{model})$),将结果映射回模型的隐藏空间。这个投影在 Multi-Head Attention 中尤其重要——它负责将多个头拼接后的表示重新混合。
4. PyTorch 手动实现
理论讲得再多,不如亲手写一遍代码。下面分别实现 Single-Head 和 Multi-Head Self-Attention。
4.1 Single-Head Self-Attention
1 | import torch |
4.2 Multi-Head Self-Attention
1 | class MultiHeadSelfAttention(nn.Module): |
上面的代码完整展示了从输入到输出的每一步。值得注意的是,.transpose(1, 2) 这一步将 (batch, seq_len, num_heads, head_dim) 变成 (batch, num_heads, seq_len, head_dim),这样后续的矩阵乘法就能在每个头内部独立进行,利用 batch 维度的并行性一次算完所有头。
5. 详细维度推导
抽象的维度符号容易让人迷糊。下面用一组具体数值,完整跟踪 Multi-Head Self-Attention 中每一步的张量形状。
5.1 参数设定
1 | seq_len = 6 # 序列长度(6 个 token) |
5.2 逐步跟踪
输入
1 | X: (1, 6, 512) # 1 个样本,6 个 token,每个 token 512 维 |
线性投影
1 | W_Q: (512, 512) # nn.Linear 的权重矩阵 |
切分多头 + 转置
1 | Q.view(1, 6, 8, 64): (1, 6, 8, 64) |
计算注意力分数
1 | K.transpose(-2, -1): (1, 8, 64, 6) # 转置最后两维 |
缩放
1 | scores / sqrt(64): (1, 8, 6, 6) # sqrt(64) = 8,每个分数除以 8 |
Softmax
1 | softmax(scores, dim=-1): (1, 8, 6, 6) # 每行(最后一维)归一化为概率分布 |
加权求和
1 | attn_weights @ V: (1, 8, 6, 6) @ (1, 8, 6, 64) = (1, 8, 6, 64) |
多头拼接
1 | transpose(1, 2): (1, 6, 8, 64) # 把 head 维度挪回去 |
输出投影
1 | W_O: (512, 512) |
最终输出
1 | Output: (1, 6, 512) # 与输入 X 形状完全一致 |
从头到尾,输入和输出的形状都是 $(batch, seq_len, d_{model})$,这保证了多个 Transformer Block 可以像积木一样层层堆叠。
6. 为什么要除以 $\sqrt{d_k}$:方差稳定性的数学推导
除以 $\sqrt{d_k}$ 这一步在直觉上容易被当作”调参技巧”一笔带过,但它背后有严格的数学原因。
6.1 问题提出
假设 $Q$ 和 $K$ 的每个分量都是均值为 0、方差为 1 的独立随机变量。我们来计算点积 $q \cdot k$ 的方差。
设 $q = (q_1, q_2, \ldots, q_{d_k})$,$k = (k_1, k_2, \ldots, k_{d_k})$,其中每个 $q_i$ 和 $k_j$ 独立同分布,满足 $E[q_i] = 0$,$\text{Var}(q_i) = 1$。
点积的定义是:
$$
q \cdot k = \sum_{i=1}^{d_k} q_i \cdot k_i
$$
6.2 推导过程
首先,单个分量乘积的期望和方差:
$$
E[q_i \cdot k_i] = E[q_i] \cdot E[k_i] = 0 \times 0 = 0
$$
$$
\text{Var}(q_i \cdot k_i) = E[q_i^2 \cdot k_i^2] - (E[q_i \cdot k_i])^2
$$
由于 $q_i$ 和 $k_i$ 独立:
$$
E[q_i^2 \cdot k_i^2] = E[q_i^2] \cdot E[k_i^2] = \text{Var}(q_i) \cdot \text{Var}(k_i) = 1 \times 1 = 1
$$
所以:
$$
\text{Var}(q_i \cdot k_i) = 1 - 0 = 1
$$
点积是 $d_k$ 个独立随机变量之和,根据方差的可加性:
$$
\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i \cdot k_i) = d_k
$$
6.3 结论
点积 $q \cdot k$ 的方差是 $d_k$。这意味着当 $d_k = 64$ 时,点积值的标准差大约是 8;当 $d_k = 128$ 时,标准差大约是 11.3。维度越大,点积值的绝对值越大,分布越分散。
Softmax 函数 $\text{softmax}(z_i) = \exp(z_i) / \sum \exp(z_j)$ 对输入值的量级非常敏感。当输入值之间的差距很大时(比如一个值是 50,另一个是 -10),Softmax 的输出会极度集中在最大值上,接近 one-hot 分布。此时梯度几乎为零,参数无法更新,训练陷入停滞。
除以 $\sqrt{d_k}$ 之后:
$$
\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\text{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1
$$
方差被拉回到 1,无论 $d_k$ 取什么值,Softmax 的输入始终在一个合理的范围内波动,梯度保持健康。这就是 Scaled Dot-Product Attention 名字中 “Scaled” 一词的由来。
7. Softmax 的数值稳定性
Softmax 看起来是一个简单的公式,但在实际的 GPU 实现中,它隐藏着一个容易导致数值溢出的陷阱。
7.1 溢出问题
标准 Softmax 公式:
$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}
$$
问题在于 $\exp$ 函数增长极快。当 $z_i$ 值较大时(比如 $z_i = 1000$),$\exp(1000)$ 直接超出 float32 甚至 float16 的表示范围,得到 inf。即使分子分母同时溢出,inf / inf 会得到 NaN,整个计算就废了。
7.2 减最大值技巧
解决方案是利用 Softmax 的平移不变性:对所有输入减去最大值,不改变结果。
$$
\text{softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^{N} e^{z_j - \max(z)}}
$$
数学证明很简单——分子分母同乘以 $\exp(-\max(z))$,等价于分子分母各自除以 $\exp(\max(z))$,值不变。减去最大值后,最大的指数输入是 0,所以 $\exp$ 的结果最大为 1,不会溢出。
但这带来了一个效率问题:标准实现需要对数据做三遍扫描:
- 第一遍:遍历所有元素,找最大值 $m$
- 第二遍:遍历所有元素,计算 $\exp(z_i - m)$ 和它们的总和 $\text{sum}$
- 第三遍:遍历所有元素,每个 $\exp(z_i - m)$ 除以 $\text{sum}$
每遍扫描都要从 HBM(高带宽显存)读取数据,三遍就是三次 HBM 读取。当 $N$ 很大时(比如 128K 的序列长度),这个 IO 开销非常可观。
7.3 Online Softmax
Milakov 和 Gimelshein 在 2018 年提出了 Online Softmax 算法,将三遍扫描合并为一遍扫描,核心思想是在遍历数据的同时动态维护最大值和归一化分母。
算法流程如下。在遍历到第 $i$ 个元素时,维护两个变量:
- $m_i$:前 $i$ 个元素的最大值
- $d_i$:前 $i$ 个元素的指数和(以 $m_i$ 为基准)
递推公式:
$$
\begin{aligned}
m_i &= \max(m_{i-1},; z_i) \
d_i &= d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{z_i - m_i}
\end{aligned}
$$
关键技巧在第二行:当最大值从 $m_{i-1}$ 更新到 $m_i$ 时,之前累积的指数和 $d_{i-1}$ 需要乘以一个修正因子 $\exp(m_{i-1} - m_i)$ 来”换基”。这样只需一遍扫描就能同时得到最大值和指数和,然后再做一遍扫描完成归一化——总共两遍,比标准实现少一遍。
这个思想对 FlashAttention 至关重要:FlashAttention 将 Attention 矩阵分成小块(tile)逐块处理,每处理一块就需要更新 Softmax 的中间结果。如果不能在线更新 Softmax,就必须把整个 $N \times N$ 矩阵写入 HBM 后再做全局 Softmax,那就失去了分块计算的意义。Online Softmax 让分块计算和精确 Softmax 得以兼容。
8. Multi-Head Attention
8.1 为什么需要多头
单头 Attention 只有一组 QKV 投影,意味着模型只能学习一种”关注模式”。但语言中 token 之间的关系是多维度的——同一个词和其他词之间可能同时存在句法关系(主谓一致)、语义关系(同义替换)、位置关系(相邻词的局部模式)等。
打个比方:在一次项目评审会议上,只派一个评审员去审阅整个项目,他只能从自己擅长的角度提出意见。如果派出一个评审团——一位看技术架构,一位看代码质量,一位看测试覆盖率,一位看文档完整性——每个人独立给出评分和建议,最后汇总成一份综合评审报告,覆盖面就远比单人评审要全面得多。
Multi-Head Attention 就是这个”评审团”机制。每个头有自己独立的 $W_Q$、$W_K$、$W_V$ 投影参数,在不同的子空间中捕捉不同类型的关系。
8.2 数学原理
假设 $d_{model} = 512$,$num_heads = 8$,则 $head_dim = 64$。
Multi-Head Attention 的完整公式:
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W_O
$$
$$
\text{head}_i = \text{Attention}(Q \cdot W_Q^i, K \cdot W_K^i, V \cdot W_V^i)
$$
其中每个头的投影矩阵 $W_Q^i$、$W_K^i$、$W_V^i$ 的形状是 $(d_{model}, head_dim) = (512, 64)$。但实际实现中,并不会真的维护 $h$ 组小矩阵——而是用一个大矩阵 $W_Q$(形状 $(512, 512)$)做一次投影,然后 reshape 成 $(seq_len, 8, 64)$ 来切分。这样做是等价的:大矩阵可以看作 8 个小矩阵纵向拼接。
8.3 参数量分析
以 $d_{model} = 512$ 为例:
| 参数矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_Q$ | $(512, 512)$ | 262,144 |
| $W_K$ | $(512, 512)$ | 262,144 |
| $W_V$ | $(512, 512)$ | 262,144 |
| $W_O$ | $(512, 512)$ | 262,144 |
| 合计 | $4 \times 512^2 = 1{,}048{,}576$ |
通用公式:MHA 的参数量为 **$4 \times d_{model}^2$**(不算 bias 的情况)。不管有多少个头,总参数量不变——头数只影响切分方式,不影响总参数。这是因为每增加一个头,每个头的维度相应减小,两者乘积(即总投影维度)始终等于 $d_{model}$。
8.4 多头的工程意义
多头结构天然适合并行化。8 个头的计算完全独立,可以:
- GPU 内并行:利用 batch 维度,在一次 CUDA kernel 启动中同时处理所有头
- 多卡张量并行(Tensor Parallelism):将不同头分配到不同 GPU 上,每张 GPU 只计算自己负责的若干头。比如 8 个头分到 4 张 GPU,每张处理 2 个头。最后通过一次 AllReduce 通信汇总输出投影的结果
9. MHA vs MQA vs GQA vs MLA
随着大模型进入推理效率至上的时代,Attention 头的结构也在不断演化。核心驱动力是一个问题:KV Cache 太大了。
9.1 MHA (Multi-Head Attention)
标准的 MHA 中,每个注意力头都有独立的 Q、K、V。如果有 $h$ 个头,那就有 $h$ 组 K 和 $h$ 组 V 需要缓存。
- $h$ 个 Q 头,$h$ 个 K 头,$h$ 个 V 头
- 每组 Q/K/V 独立,互不共享
9.2 MQA (Multi-Query Attention)
Shazeer 在 2019 年提出:所有 Q 头共享同一组 K 和 V。也就是说,只有 1 个 K 头和 1 个 V 头,但仍然有 $h$ 个 Q 头。
- $h$ 个 Q 头,1 个 K 头,1 个 V 头
- 所有 Q 头共享同一份 K 和 V
好处是 KV Cache 缩小为原来的 $1/h$,推理速度大幅提升。代价是模型表达能力下降——所有头被迫从同一组 KV 中提取信息,多样性受限。
9.3 GQA (Grouped-Query Attention)
GQA 是 MHA 和 MQA 的折中方案(Ainslie et al., 2023)。将 $h$ 个 Q 头分成 $g$ 个组(每组 $h/g$ 个 Q 头),每组共享一组 K 和 V。
- $h$ 个 Q 头,$g$ 个 K 头,$g$ 个 V 头
- 每 $h/g$ 个 Q 头共享一组 K 和 V
- 当 $g = h$ 时退化为 MHA,当 $g = 1$ 时退化为 MQA
GQA 在模型质量和推理效率之间取得了更好的平衡。LLaMA-2-70B、Mistral-7B 等主流模型都采用了 GQA。
9.4 对比表格
以 $d_{model} = 4096$,$num_heads = 32$,$head_dim = 128$ 为例,假设序列长度 $N = 4096$,FP16 精度:
| 指标 | MHA ($g=32$) | GQA ($g=8$) | MQA ($g=1$) |
|---|---|---|---|
| Q 头数 | 32 | 32 | 32 |
| KV 头数 | 32 | 8 | 1 |
| $W_K$ 参数量 | $32 \times 128 \times 4096 = 16\text{M}$ | $8 \times 128 \times 4096 = 4\text{M}$ | $1 \times 128 \times 4096 = 0.5\text{M}$ |
| $W_V$ 参数量 | $16\text{M}$ | $4\text{M}$ | $0.5\text{M}$ |
| 单 token KV Cache | $2 \times 32 \times 128 \times 2\text{B} = 16\text{KB}$ | $2 \times 8 \times 128 \times 2\text{B} = 4\text{KB}$ | $2 \times 1 \times 128 \times 2\text{B} = 0.5\text{KB}$ |
| 4096 token KV Cache / 层 | $64\text{MB}$ | $16\text{MB}$ | $2\text{MB}$ |
| 模型质量 | 最好 | 接近 MHA | 有下降 |
可以看到,从 MHA 到 GQA($g=8$),KV Cache 缩小为 1/4,参数量减少,但模型质量几乎不受影响。这就是 GQA 成为主流选择的原因。
9.5 MLA (Multi-Latent Attention)
DeepSeek-V2 提出了 MLA(Multi-Latent Attention),代表了另一种压缩 KV Cache 的思路。
MLA 的核心思想是:不直接缓存完整的 K 和 V 向量,而是将它们压缩到一个低维潜在空间(latent space),只缓存这个低维表示。推理时再将低维表示解压回完整的 K 和 V。
- 标准 MHA:$X \to W_K \to K$(缓存 $K$,维度 = num_kv_heads $\times$ head_dim)
- MLA:$X \to W_{DKV} \to c$(缓存 $c$,维度远小于 $K$ 的维度),推理时 $c \to W_{UK} \to K$ 按需解压
MLA 通过低秩压缩将 KV Cache 大幅缩小(DeepSeek-V2 报告了约 93.3% 的压缩率),同时通过精心设计的上投影矩阵保持了模型质量。它与 GQA 的思路不同——GQA 是减少 KV 头的数量,MLA 是降低每个表示的维度——但目标一致:让推理时的 KV Cache 尽可能小。
10. Causal Mask 的作用与实现
10.1 为什么需要遮蔽
在自回归语言模型(如 GPT 系列、LLaMA 系列)中,模型的训练目标是”根据前文预测下一个 token”。这要求在计算 Attention 时,每个位置的 token 只能看到自己和它之前的 token,不能”偷看”未来的信息。否则,模型在训练时就能看到答案,学不到任何有意义的预测能力。
10.2 实现方式
Causal Mask(因果遮蔽)是一个上三角矩阵,覆盖在注意力分数矩阵上,将未来位置的分数设为负无穷:
1 | # 构造因果遮蔽矩阵 |
Causal Mask 的另一个工程意义在于:它使得 Attention 矩阵只有下三角部分有效,理论上可以跳过上三角部分的计算。FlashAttention-2 正是利用了这一点——在处理上三角区域的 tile 时直接跳过,减少了接近一半的计算量。
11. Self-Attention 的计算瓶颈分析
理解 Self-Attention 的性能瓶颈,需要区分两种不同的场景。
11.1 计算量分析
Self-Attention 中有四次主要的矩阵乘法:
| 运算 | 形状 | FLOPs |
|---|---|---|
| $Q = X \cdot W_Q$ | $(N, d) \times (d, d)$ | $2Nd^2$ |
| $K = X \cdot W_K$ | $(N, d) \times (d, d)$ | $2Nd^2$ |
| $V = X \cdot W_V$ | $(N, d) \times (d, d)$ | $2Nd^2$ |
| $S = Q \cdot K^T$ | $(N, d) \times (d, N)$ | $2N^2d$ |
| $O = A \cdot V$ | $(N, N) \times (N, d)$ | $2N^2d$ |
| $Final = O \cdot W_O$ | $(N, d) \times (d, d)$ | $2Nd^2$ |
总计:$8Nd^2 + 4N^2d$。
当 $N$ 较小(比如 $N < d$)时,$8Nd^2$ 项主导——瓶颈在 QKV 投影的 GEMM 运算。当 $N$ 较大(比如 $N > d$)时,$4N^2d$ 项主导——瓶颈在 Attention 矩阵的计算。
11.2 Compute Bound vs Memory Bound
GPU 的性能受两个指标限制:
- 算力(Compute):每秒能做多少次浮点运算(FLOPS)
- 带宽(Memory Bandwidth):每秒能从 HBM 读写多少数据(GB/s)
一个运算是 Compute Bound 还是 Memory Bound,取决于它的算术强度(Arithmetic Intensity)= FLOPs / Bytes。
- 如果算术强度 > GPU 的算力/带宽比值,运算是 Memory Bound——GPU 算力用不满,大部分时间在等数据搬运
- 如果算术强度 < GPU 的算力/带宽比值,运算是 Compute Bound——GPU 带宽够用,大部分时间在做计算
以 A100 为例:FP16 算力约 312 TFLOPS,HBM 带宽约 2 TB/s,算力/带宽比值约 156 FLOP/Byte。
Prefill 阶段(处理完整 prompt):
QKV 投影是大矩阵乘法,batch 维度(seq_len)很大,算术强度高,通常是 Compute Bound。Attention 的 $Q \cdot K^T$ 也是大矩阵乘,同样是 Compute Bound。
Decode 阶段(逐 token 生成):
每步只有 1 个 token,QKV 投影退化为矩阵-向量乘法(GEMV),算术强度极低,是 Memory Bound。Attention 退化为一个向量和整个 KV Cache 的运算,同样 Memory Bound。大部分时间花在从 HBM 搬运模型权重和 KV Cache 上。
11.3 $O(N^2)$ 的显存瓶颈
标准 Attention 实现需要显式地计算并存储完整的 $N \times N$ 注意力矩阵。以 $seq_len = 128\text{K}$、$num_heads = 32$ 为例:
$$
32 \times 128\text{K} \times 128\text{K} \times 2\text{B} = 32 \times 16{,}384\text{M} \times 2 = 1{,}048{,}576\text{ MB} = 1\text{ TB}
$$
这显然远超任何单张 GPU 的显存容量。即使序列长度”只有”8K,注意力矩阵也需要 $32 \times 8\text{K} \times 8\text{K} \times 2\text{B} = 4\text{ GB}$,已经是一个不可忽视的开销。
这个 $O(N^2)$ 的显存瓶颈,正是 FlashAttention 要解决的核心问题。
12. FlashAttention 的核心思想
FlashAttention(Dao et al., 2022)是过去几年 Attention 优化领域最具影响力的工作。它不改变 Attention 的计算结果(是精确计算,不是近似),但通过重新编排计算顺序,大幅减少了对 HBM 的访问量。
12.1 GPU 存储层次回顾
要理解 FlashAttention,需要先了解 GPU 的存储层次:
- HBM(High Bandwidth Memory):GPU 的主显存,容量大(如 A100 的 80 GB)但访问速度相对慢(2 TB/s)
- SRAM(片上缓存):每个 SM(Streaming Multiprocessor)上的共享内存和寄存器,容量很小(如 A100 每个 SM 约 192 KB 共享内存,全部 SM 合计约 20 MB)但访问速度极快(约 19 TB/s)
标准 Attention 的流程是:在 HBM 中完成 $Q \cdot K^T$,把完整的 $N \times N$ 注意力矩阵写回 HBM;然后从 HBM 读取注意力矩阵做 Softmax,结果写回 HBM;最后从 HBM 读取 Softmax 结果和 $V$ 做矩阵乘法。每一步都涉及对 $N \times N$ 大矩阵的 HBM 读写——这是巨大的带宽浪费。
12.2 Tiling:分块计算
FlashAttention 的第一个关键技术是 Tiling(分块):不一次性计算完整的 $N \times N$ 注意力矩阵,而是将 Q、K、V 分成若干小块(tile),每次只加载一小块 Q 和一小块 K、V 到 SRAM 中,在 SRAM 内完成该块的 Attention 计算,然后将结果写回 HBM。
1 | 标准实现: |
每次只处理一个 tile,注意力矩阵的这一小块始终驻留在 SRAM 中,永远不需要把完整的 $N \times N$ 矩阵写入 HBM。
12.3 Online Softmax 的关键作用
Tiling 带来一个棘手的问题:Softmax 需要对每一行的所有元素做归一化(需要全局最大值和全局求和),但分块计算时每次只看到一行的一部分——怎么在只看到局部数据的情况下计算全局 Softmax?
这就是 Online Softmax 派上用场的地方。回忆第 7.3 节的 Online Softmax 递推公式:当新数据块到来时,可以动态更新全局最大值和归一化分母,并修正之前块的计算结果。
具体来说,处理 $Q_i$ 对 $K_1$ 块的 Attention 后,得到一个局部输出 $O_1$ 和对应的 Softmax 统计量(局部最大值 $m_1$ 和局部分母 $l_1$)。当继续处理 $Q_i$ 对 $K_2$ 块时,会得到新的局部统计量 $m_2$ 和 $l_2$。通过 Online Softmax 的修正:
$$
\begin{aligned}
m_{\text{new}} &= \max(m_1, m_2) \
l_{\text{new}} &= l_1 \cdot e^{m_1 - m_{\text{new}}} + l_2 \cdot e^{m_2 - m_{\text{new}}} \
O_{\text{new}} &= \frac{O_1 \cdot l_1 \cdot e^{m_1 - m_{\text{new}}} + O_2^{\text{local}} \cdot l_2 \cdot e^{m_2 - m_{\text{new}}}}{l_{\text{new}}}
\end{aligned}
$$
不断累积,直到所有 $K$ 块处理完毕,最终结果与标准 Attention 的全局 Softmax 在数学上完全一致。
12.4 为什么能减少 HBM 访问
标准 Attention 的 HBM 访问量:
- 写入 $S$ ($N \times N$):$O(N^2)$ 次写
- 读取 $S$ 做 Softmax:$O(N^2)$ 次读
- 写入 $P = \text{Softmax}(S)$:$O(N^2)$ 次写
- 读取 $P$ 做 $P \cdot V$:$O(N^2)$ 次读
- **总 HBM 访问量:$O(N^2)$**(加上 Q, K, V 本身的 $O(Nd)$ 读取)
FlashAttention 的 HBM 访问量:
- 读取 Q, K, V:$O(Nd)$
- 写入最终输出 O:$O(Nd)$
- 中间的注意力矩阵始终在 SRAM 中,不写入 HBM
- 总 HBM 访问量:$O(Nd)$
从 $O(N^2)$ 降到 $O(Nd)$——当 $N$ 远大于 $d$ 时(长序列场景),这是一个数量级的改进。注意,计算量(FLOPs)没有变——仍然是 $O(N^2d)$——改变的只是 IO 模式。FlashAttention 的加速本质上来自于减少了对慢速 HBM 的访问,让计算尽可能在快速 SRAM 中完成。
12.5 FlashAttention-2 和 FlashAttention-3
FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上进一步优化了并行策略和 warp 调度:
- 调整了内外循环的顺序(外循环遍历 Q 块,内循环遍历 K/V 块),减少了共享内存的读写
- 更好的 warp 级工作分配,提升了 GPU 的占用率
- 利用 Causal Mask 跳过全零的 tile,进一步减少计算量
FlashAttention-3(Dao et al., 2024)则针对 Hopper 架构(H100)进行了优化,利用了 TMA(Tensor Memory Accelerator)和 wgmma(Warp Group MMA)指令,进一步提升了硬件利用率。
目前 FlashAttention 已经是所有主流推理和训练框架的标配——PyTorch 2.0+ 内置了 torch.nn.functional.scaled_dot_product_attention,底层默认调用 FlashAttention kernel。
🎯 自我检验清单
完成本文学习后,用以下问题检验自己的理解深度:
- 能从 Bahdanau Attention 到 Self-Attention 梳理出 Attention 机制的演进脉络,说清每一步解决了什么问题
- 能默写 Scaled Dot-Product Attention 的完整公式,并解释 Q、K、V 三元组各自的作用
- 能从零推导 “除以 $\sqrt{d_k}$” 的数学原因:假设 $q_i, k_i$ 独立标准正态,推出点积方差为 $d_k$
- 能手写 Single-Head 和 Multi-Head Self-Attention 的 PyTorch 实现,能说清每一步的张量维度变化
- 能解释 Softmax 为什么要减最大值,以及 Online Softmax 的核心递推思想
- 给定 $seq_len$、$d_{model}$、$num_heads$ 的具体数值,能完整跟踪每一步的张量形状
- 能画出 MHA / MQA / GQA 的结构差异,并估算各自的 KV Cache 大小
- 能说清 Causal Mask 的作用和实现方式(上三角设为负无穷,Softmax 后归零)
- 能区分 Compute Bound 和 Memory Bound,并说清 Prefill 和 Decode 阶段各自的瓶颈类型
- 能说清 FlashAttention 的核心思想:Tiling + Online Softmax 如何将 HBM 访问从 $O(N^2)$ 降到 $O(Nd)$
📚 参考资料
论文
- Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al., 2014):https://arxiv.org/abs/1409.0473 – Bahdanau Attention,注意力机制的开山之作
- Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015):https://arxiv.org/abs/1508.04025 – Luong Attention,引入点积注意力
- Attention Is All You Need (Vaswani et al., 2017):https://arxiv.org/abs/1706.03762 – Transformer 原始论文
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019):https://arxiv.org/abs/1911.02150 – Multi-Query Attention
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023):https://arxiv.org/abs/2305.13245 – Grouped-Query Attention
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (DeepSeek-AI, 2024):https://arxiv.org/abs/2405.04434 – Multi-Latent Attention
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022):https://arxiv.org/abs/2205.14135 – FlashAttention v1
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023):https://arxiv.org/abs/2307.08691 – FlashAttention v2
- Online Normalizer Calculation for Softmax (Milakov & Gimelshein, 2018):https://arxiv.org/abs/1805.02867 – Online Softmax 算法
教程与博客
- The Illustrated Transformer (Jay Alammar):https://jalammar.github.io/illustrated-transformer/ – 图文并茂的 Transformer 入门
- The Annotated Transformer (Harvard NLP):https://nlp.seas.harvard.edu/annotated-transformer/ – 论文逐行对应 PyTorch 实现
- Andrej Karpathy: Let’s build GPT from scratch:https://www.youtube.com/watch?v=kCc8FmEb1nY – 从零手写 GPT
- ELI5: FlashAttention (Aleksa Gordic):https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad – FlashAttention 通俗讲解