2.3 Self-Attention机制深入理解

Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块。无论你是想理解 FlashAttention 背后的 IO 优化思想,还是想搞清楚 GQA、MLA 这些 Attention 变种为什么能减少推理开销,都绕不开对 Self-Attention 机制的深入理解。本文将从 Attention 的历史起源讲起,逐步拆解 Scaled Dot-Product Attention 的每一步数学原理,手写 PyTorch 实现,分析计算瓶颈,最后延伸到 FlashAttention 和各种 Attention 变种,力求让读者建立从直觉到公式再到工程实现的完整认知链条。

📑 目录


1. Attention 的历史演进

在 2017 年 Transformer 横空出世之前,Attention 机制已经在序列建模领域酝酿了好几年。理解它的演进脉络,有助于我们把握 Self-Attention 到底解决了什么问题,以及为什么它的设计长成今天这个样子。

1.1 Bahdanau Attention (2014):开山之作

传统的 Encoder-Decoder 架构(比如用 RNN 做机器翻译)有一个致命缺陷:Encoder 把整个输入序列压缩成一个固定长度的向量,再交给 Decoder 去生成输出。当输入序列很长时,这个固定向量根本装不下所有信息,翻译质量会急剧下降。

Bahdanau 等人在 2014 年提出了一个关键改进:别硬压缩了,让 Decoder 在生成每个词的时候,自己回头”看”Encoder 的所有隐状态,按需取用。具体做法是,Decoder 当前时刻的隐状态作为”查询”,跟 Encoder 每个时刻的隐状态做”匹配”(通过一个小型前馈网络计算匹配分数),然后用匹配分数对 Encoder 的隐状态加权求和,得到一个”上下文向量”,作为当前解码步的额外输入。

这就是 Attention 的雏形:让模型学会”注意”输入序列的不同部分。不过 Bahdanau Attention 的匹配函数是一个额外的前馈网络(additive attention),计算效率不高。

1.2 Luong Attention (2015):简化计算

Luong 在 2015 年对 Bahdanau Attention 做了两个关键简化:一是提出了更高效的匹配函数(直接用点积代替前馈网络),二是探索了多种对齐方式(全局 vs 局部)。其中,点积 Attention 的计算方式——两个向量直接做内积来衡量相似度——后来成为了 Transformer 的基础。

点积的计算效率远高于前馈网络:它可以被表示为矩阵乘法,天然适合 GPU 并行加速。这个看似简单的改进,为后来 Self-Attention 的大规模应用铺平了道路。

1.3 Self-Attention (2017):Attention Is All You Need

2017 年的 Transformer 论文做了一个大胆的决定:完全抛弃 RNN,只用 Attention 来建模序列

之前的 Attention 都是”跨序列”的——Decoder 去关注 Encoder。而 Self-Attention 是”自关注”——序列中的每个位置去关注同一个序列中的所有位置(包括自己)。这样做有两大优势:

  • 并行性:RNN 必须逐步计算(第 $t$ 步依赖第 $t-1$ 步的结果),而 Self-Attention 中所有位置可以同时计算,天然适合 GPU 并行。
  • 长距离依赖:RNN 中距离较远的两个词需要通过多步传递才能交互信息,信号会逐步衰减。Self-Attention 中任意两个位置之间只需要一步就能直接交互,理论上不存在距离衰减问题。

代价是什么?Self-Attention 的计算复杂度是 $O(N^2)$——序列中每对位置都需要计算匹配分数,而 RNN 的复杂度是 $O(N)$。这个平方复杂度成为后续无数优化工作的出发点。


2. 从信息检索角度理解 QKV

Attention 中最核心的三个概念是 Query(Q)、Key(K)和 Value(V)。这三个词直接借鉴了信息检索领域的术语,用一个日常场景来类比可以帮助建立直觉。

2.1 图书馆检索的类比

想象你走进一座图书馆,你脑子里有一个模糊的需求:”我想了解关于并行计算的内容”。这个需求就是你的 Query

图书馆里每一本书的书脊上都贴着标签——“操作系统”、”计算机网络”、”并行计算导论”、”数据结构”等等。这些标签就是每本书的 Key

你拿着自己的 Query,去和每本书的 Key 做比对。”并行计算导论”这个 Key 和你的 Query 高度匹配,匹配分数最高;”操作系统”可能有一些相关性,分数中等;”数据结构”和你的需求关系不大,分数很低。

确定了匹配分数之后,你不是把书脊标签(Key)拿走,而是把每本书的实际内容(Value)按照匹配分数加权汇总。”并行计算导论”的内容占大比重,”操作系统”的内容占一点,”数据结构”几乎不占——最终你得到的就是一份以并行计算为主、兼顾一点操作系统知识的综合信息。

2.2 三元组的形式化定义

回到 Self-Attention 的语境。给定输入序列 $X$(形状为 $(N, d_{model})$),三个线性变换将 $X$ 映射到不同的空间:

  • **$Q = X \cdot W_Q$**:每个 token 生成自己的”查询向量”——“我需要什么信息”
  • **$K = X \cdot W_K$**:每个 token 生成自己的”索引向量”——“我能提供什么信息的线索”
  • **$V = X \cdot W_V$**:每个 token 生成自己的”内容向量”——“我实际携带的信息”

Q、K、V 之所以要从同一个 $X$ 做三次不同的线性变换,而不是直接用 $X$ 本身,原因在于解耦不同的角色。一个 token “需要什么”和它”能提供什么”往往是不同的。比如在一个句子里,动词可能需要关注它的主语和宾语(Query 的方向),但它作为被别人关注的对象时,提供的是动作语义(Key/Value 的方向)。三个独立的投影矩阵让模型有自由度去学习这些不同的映射。


3. Scaled Dot-Product Attention 详解

有了 Q、K、V 之后,Attention 的计算过程可以分解为清晰的五步。

3.1 第一步:线性投影

输入 $X$ 的形状为 $(N, d_{model})$。三个权重矩阵 $W_Q$、$W_K$、$W_V$ 的形状都是 $(d_{model}, d_{model})$(以单头为例):

$$
\begin{aligned}
Q &= X \cdot W_Q \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \
K &= X \cdot W_K \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \
V &= X \cdot W_V \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model})
\end{aligned}
$$

这三次矩阵乘法是三次独立的 GEMM(General Matrix Multiply)操作。在实际实现中,为了提高 GPU 利用率,通常会把 $W_Q$、$W_K$、$W_V$ 合并成一个大矩阵 $W_{QKV}$(形状 $(d_{model}, 3 \times d_{model})$),做一次 GEMM 然后 split,这样能更好地利用 GPU 的算力。

3.2 第二步:计算原始注意力分数

用 $Q$ 和 $K$ 的转置做矩阵乘法,衡量每对 token 之间的匹配程度:

$$
S = Q K^\top \quad (N, d_{model}) \times (d_{model}, N) = (N, N)
$$

结果矩阵 $S$ 的每个元素 $S[i][j]$ 是第 $i$ 个 token 的查询向量和第 $j$ 个 token 的索引向量的点积,代表 token $i$ 对 token $j$ 的原始关注程度。

这一步产生了一个 $N \times N$ 的矩阵,这就是 Self-Attention 平方复杂度的根源。

3.3 第三步:缩放

将原始分数除以 $\sqrt{d_k}$($d_k$ 是 Key 向量的维度):

$$
S_{\text{scaled}} = \frac{S}{\sqrt{d_k}}
$$

为什么需要缩放?简单来说,当维度 $d_k$ 较大时,点积的结果会变得很大,导致后续 Softmax 的梯度趋近于零。详细的数学推导见第 6 节。

3.4 第四步:Softmax 归一化

对缩放后的分数矩阵按行做 Softmax,将原始分数转换为概率分布:

$$
A = \text{softmax}(S_{\text{scaled}}) \quad \in \mathbb{R}^{N \times N}, \text{ 每行和为 1}
$$

Softmax 的作用是双重的:一方面把任意实数映射到 $(0, 1)$ 区间,使其可以作为权重;另一方面保证每行的权重之和为 1,形成一个合法的概率分布。

3.5 第五步:加权求和

用注意力权重对 Value 矩阵做加权求和:

$$
\text{Output} = A \cdot V \quad (N, N) \times (N, d_{model}) = (N, d_{model})
$$

最终每个 token 位置得到一个 $d_{model}$ 维的向量,里面融合了它所”关注”的所有 token 的信息,关注程度由 $A$ 的权重决定。

3.6 完整公式

将以上步骤合并,Self-Attention 的计算可以用一行公式概括:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

最后还要再过一个输出投影矩阵 $W_O$(形状 $(d_{model}, d_{model})$),将结果映射回模型的隐藏空间。这个投影在 Multi-Head Attention 中尤其重要——它负责将多个头拼接后的表示重新混合。


4. PyTorch 手动实现

理论讲得再多,不如亲手写一遍代码。下面分别实现 Single-Head 和 Multi-Head Self-Attention。

4.1 Single-Head Self-Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadSelfAttention(nn.Module):
"""手动实现单头 Self-Attention,不依赖 PyTorch 内置的 MHA 模块"""
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
# 三个线性投影:输入维度 d_model,输出维度 d_model
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
# 输出投影
self.W_O = nn.Linear(d_model, d_model, bias=False)

def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
# x: (batch, seq_len, d_model)
# Step 1: 线性投影得到 Q, K, V
Q = self.W_Q(x) # (batch, seq_len, d_model)
K = self.W_K(x) # (batch, seq_len, d_model)
V = self.W_V(x) # (batch, seq_len, d_model)

# Step 2: 计算注意力分数 = Q @ K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, seq_len, seq_len)

# Step 3: 缩放,防止点积值过大导致 softmax 梯度消失
scores = scores / math.sqrt(self.d_model)

# Step 4: 如果提供了 mask,将被遮蔽位置的分数设为负无穷
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Step 5: Softmax 归一化,得到注意力权重
attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len)

# Step 6: 用权重对 V 加权求和
context = torch.matmul(attn_weights, V) # (batch, seq_len, d_model)

# Step 7: 输出投影
output = self.W_O(context) # (batch, seq_len, d_model)
return output

4.2 Multi-Head Self-Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class MultiHeadSelfAttention(nn.Module):
"""手动实现多头 Self-Attention:将 d_model 切分为 num_heads 个子空间"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads # 每个头的维度

# QKV 投影:输入 d_model,输出 d_model(内部包含所有头的投影)
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)

def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
batch_size, seq_len, _ = x.shape

# Step 1: 线性投影
Q = self.W_Q(x) # (batch, seq_len, d_model)
K = self.W_K(x)
V = self.W_V(x)

# Step 2: 重塑为多头形状,并转置使 head 维度在 seq_len 前面
# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, head_dim)
# -> (batch, num_heads, seq_len, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# Step 3: 每个头独立计算注意力分数
# Q @ K^T: (batch, num_heads, seq_len, head_dim)
# @ (batch, num_heads, head_dim, seq_len)
# = (batch, num_heads, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

# Step 4: 可选的因果遮蔽
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Step 5: Softmax + 加权求和
attn_weights = F.softmax(scores, dim=-1)
# (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim)
# = (batch, num_heads, seq_len, head_dim)
context = torch.matmul(attn_weights, V)

# Step 6: 多头拼接——先转回 (batch, seq_len, num_heads, head_dim),再合并最后两维
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

# Step 7: 输出投影,将拼接后的多头表示混合
output = self.W_O(context) # (batch, seq_len, d_model)
return output

上面的代码完整展示了从输入到输出的每一步。值得注意的是,.transpose(1, 2) 这一步将 (batch, seq_len, num_heads, head_dim) 变成 (batch, num_heads, seq_len, head_dim),这样后续的矩阵乘法就能在每个头内部独立进行,利用 batch 维度的并行性一次算完所有头。


5. 详细维度推导

抽象的维度符号容易让人迷糊。下面用一组具体数值,完整跟踪 Multi-Head Self-Attention 中每一步的张量形状。

5.1 参数设定

1
2
3
4
5
seq_len   = 6        # 序列长度(6 个 token)
d_model = 512 # 模型隐藏维度
num_heads = 8 # 注意力头数
head_dim = 512 / 8 = 64 # 每个头的维度
batch = 1 # 简化为 batch=1

5.2 逐步跟踪

输入

1
X: (1, 6, 512)    # 1 个样本,6 个 token,每个 token 512 维

线性投影

1
2
3
4
W_Q: (512, 512)    # nn.Linear 的权重矩阵
Q = X @ W_Q^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
K = X @ W_K^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
V = X @ W_V^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)

切分多头 + 转置

1
2
3
4
5
Q.view(1, 6, 8, 64):         (1, 6, 8, 64)
Q.transpose(1, 2): (1, 8, 6, 64) # 8 个头,每个头看到 6 个 token 的 64 维向量

K reshape 同理: (1, 8, 6, 64)
V reshape 同理: (1, 8, 6, 64)

计算注意力分数

1
2
3
K.transpose(-2, -1):         (1, 8, 64, 6)    # 转置最后两维
Q @ K^T: (1, 8, 6, 64) @ (1, 8, 64, 6) = (1, 8, 6, 6)
# 每个头产生一个 6x6 的注意力分数矩阵

缩放

1
scores / sqrt(64):           (1, 8, 6, 6)    # sqrt(64) = 8,每个分数除以 8

Softmax

1
2
softmax(scores, dim=-1):     (1, 8, 6, 6)    # 每行(最后一维)归一化为概率分布
# 每个 6 维行的元素之和为 1

加权求和

1
2
attn_weights @ V:            (1, 8, 6, 6) @ (1, 8, 6, 64) = (1, 8, 6, 64)
# 每个头输出 6 个 token 的 64 维表示

多头拼接

1
2
transpose(1, 2):             (1, 6, 8, 64)    # 把 head 维度挪回去
view(1, 6, 512): (1, 6, 512) # 8 * 64 = 512,拼接回 d_model

输出投影

1
2
W_O: (512, 512)
context @ W_O^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)

最终输出

1
Output: (1, 6, 512)    # 与输入 X 形状完全一致

从头到尾,输入和输出的形状都是 $(batch, seq_len, d_{model})$,这保证了多个 Transformer Block 可以像积木一样层层堆叠。


6. 为什么要除以 $\sqrt{d_k}$:方差稳定性的数学推导

除以 $\sqrt{d_k}$ 这一步在直觉上容易被当作”调参技巧”一笔带过,但它背后有严格的数学原因。

6.1 问题提出

假设 $Q$ 和 $K$ 的每个分量都是均值为 0、方差为 1 的独立随机变量。我们来计算点积 $q \cdot k$ 的方差。

设 $q = (q_1, q_2, \ldots, q_{d_k})$,$k = (k_1, k_2, \ldots, k_{d_k})$,其中每个 $q_i$ 和 $k_j$ 独立同分布,满足 $E[q_i] = 0$,$\text{Var}(q_i) = 1$。

点积的定义是:

$$
q \cdot k = \sum_{i=1}^{d_k} q_i \cdot k_i
$$

6.2 推导过程

首先,单个分量乘积的期望和方差:

$$
E[q_i \cdot k_i] = E[q_i] \cdot E[k_i] = 0 \times 0 = 0
$$

$$
\text{Var}(q_i \cdot k_i) = E[q_i^2 \cdot k_i^2] - (E[q_i \cdot k_i])^2
$$

由于 $q_i$ 和 $k_i$ 独立:

$$
E[q_i^2 \cdot k_i^2] = E[q_i^2] \cdot E[k_i^2] = \text{Var}(q_i) \cdot \text{Var}(k_i) = 1 \times 1 = 1
$$

所以:

$$
\text{Var}(q_i \cdot k_i) = 1 - 0 = 1
$$

点积是 $d_k$ 个独立随机变量之和,根据方差的可加性:

$$
\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i \cdot k_i) = d_k
$$

6.3 结论

点积 $q \cdot k$ 的方差是 $d_k$。这意味着当 $d_k = 64$ 时,点积值的标准差大约是 8;当 $d_k = 128$ 时,标准差大约是 11.3。维度越大,点积值的绝对值越大,分布越分散。

Softmax 函数 $\text{softmax}(z_i) = \exp(z_i) / \sum \exp(z_j)$ 对输入值的量级非常敏感。当输入值之间的差距很大时(比如一个值是 50,另一个是 -10),Softmax 的输出会极度集中在最大值上,接近 one-hot 分布。此时梯度几乎为零,参数无法更新,训练陷入停滞。

除以 $\sqrt{d_k}$ 之后:

$$
\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\text{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1
$$

方差被拉回到 1,无论 $d_k$ 取什么值,Softmax 的输入始终在一个合理的范围内波动,梯度保持健康。这就是 Scaled Dot-Product Attention 名字中 “Scaled” 一词的由来。


7. Softmax 的数值稳定性

Softmax 看起来是一个简单的公式,但在实际的 GPU 实现中,它隐藏着一个容易导致数值溢出的陷阱。

7.1 溢出问题

标准 Softmax 公式:

$$
\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}
$$

问题在于 $\exp$ 函数增长极快。当 $z_i$ 值较大时(比如 $z_i = 1000$),$\exp(1000)$ 直接超出 float32 甚至 float16 的表示范围,得到 inf。即使分子分母同时溢出,inf / inf 会得到 NaN,整个计算就废了。

7.2 减最大值技巧

解决方案是利用 Softmax 的平移不变性:对所有输入减去最大值,不改变结果。

$$
\text{softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^{N} e^{z_j - \max(z)}}
$$

数学证明很简单——分子分母同乘以 $\exp(-\max(z))$,等价于分子分母各自除以 $\exp(\max(z))$,值不变。减去最大值后,最大的指数输入是 0,所以 $\exp$ 的结果最大为 1,不会溢出。

但这带来了一个效率问题:标准实现需要对数据做三遍扫描

  1. 第一遍:遍历所有元素,找最大值 $m$
  2. 第二遍:遍历所有元素,计算 $\exp(z_i - m)$ 和它们的总和 $\text{sum}$
  3. 第三遍:遍历所有元素,每个 $\exp(z_i - m)$ 除以 $\text{sum}$

每遍扫描都要从 HBM(高带宽显存)读取数据,三遍就是三次 HBM 读取。当 $N$ 很大时(比如 128K 的序列长度),这个 IO 开销非常可观。

7.3 Online Softmax

Milakov 和 Gimelshein 在 2018 年提出了 Online Softmax 算法,将三遍扫描合并为一遍扫描,核心思想是在遍历数据的同时动态维护最大值和归一化分母。

算法流程如下。在遍历到第 $i$ 个元素时,维护两个变量:

  • $m_i$:前 $i$ 个元素的最大值
  • $d_i$:前 $i$ 个元素的指数和(以 $m_i$ 为基准)

递推公式:

$$
\begin{aligned}
m_i &= \max(m_{i-1},; z_i) \
d_i &= d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{z_i - m_i}
\end{aligned}
$$

关键技巧在第二行:当最大值从 $m_{i-1}$ 更新到 $m_i$ 时,之前累积的指数和 $d_{i-1}$ 需要乘以一个修正因子 $\exp(m_{i-1} - m_i)$ 来”换基”。这样只需一遍扫描就能同时得到最大值和指数和,然后再做一遍扫描完成归一化——总共两遍,比标准实现少一遍。

这个思想对 FlashAttention 至关重要:FlashAttention 将 Attention 矩阵分成小块(tile)逐块处理,每处理一块就需要更新 Softmax 的中间结果。如果不能在线更新 Softmax,就必须把整个 $N \times N$ 矩阵写入 HBM 后再做全局 Softmax,那就失去了分块计算的意义。Online Softmax 让分块计算和精确 Softmax 得以兼容。


8. Multi-Head Attention

8.1 为什么需要多头

单头 Attention 只有一组 QKV 投影,意味着模型只能学习一种”关注模式”。但语言中 token 之间的关系是多维度的——同一个词和其他词之间可能同时存在句法关系(主谓一致)、语义关系(同义替换)、位置关系(相邻词的局部模式)等。

打个比方:在一次项目评审会议上,只派一个评审员去审阅整个项目,他只能从自己擅长的角度提出意见。如果派出一个评审团——一位看技术架构,一位看代码质量,一位看测试覆盖率,一位看文档完整性——每个人独立给出评分和建议,最后汇总成一份综合评审报告,覆盖面就远比单人评审要全面得多。

Multi-Head Attention 就是这个”评审团”机制。每个头有自己独立的 $W_Q$、$W_K$、$W_V$ 投影参数,在不同的子空间中捕捉不同类型的关系。

8.2 数学原理

假设 $d_{model} = 512$,$num_heads = 8$,则 $head_dim = 64$。

Multi-Head Attention 的完整公式:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W_O
$$

$$
\text{head}_i = \text{Attention}(Q \cdot W_Q^i, K \cdot W_K^i, V \cdot W_V^i)
$$

其中每个头的投影矩阵 $W_Q^i$、$W_K^i$、$W_V^i$ 的形状是 $(d_{model}, head_dim) = (512, 64)$。但实际实现中,并不会真的维护 $h$ 组小矩阵——而是用一个大矩阵 $W_Q$(形状 $(512, 512)$)做一次投影,然后 reshape 成 $(seq_len, 8, 64)$ 来切分。这样做是等价的:大矩阵可以看作 8 个小矩阵纵向拼接。

8.3 参数量分析

以 $d_{model} = 512$ 为例:

参数矩阵 形状 参数量
$W_Q$ $(512, 512)$ 262,144
$W_K$ $(512, 512)$ 262,144
$W_V$ $(512, 512)$ 262,144
$W_O$ $(512, 512)$ 262,144
合计 $4 \times 512^2 = 1{,}048{,}576$

通用公式:MHA 的参数量为 **$4 \times d_{model}^2$**(不算 bias 的情况)。不管有多少个头,总参数量不变——头数只影响切分方式,不影响总参数。这是因为每增加一个头,每个头的维度相应减小,两者乘积(即总投影维度)始终等于 $d_{model}$。

8.4 多头的工程意义

多头结构天然适合并行化。8 个头的计算完全独立,可以:

  • GPU 内并行:利用 batch 维度,在一次 CUDA kernel 启动中同时处理所有头
  • 多卡张量并行(Tensor Parallelism):将不同头分配到不同 GPU 上,每张 GPU 只计算自己负责的若干头。比如 8 个头分到 4 张 GPU,每张处理 2 个头。最后通过一次 AllReduce 通信汇总输出投影的结果

9. MHA vs MQA vs GQA vs MLA

随着大模型进入推理效率至上的时代,Attention 头的结构也在不断演化。核心驱动力是一个问题:KV Cache 太大了

9.1 MHA (Multi-Head Attention)

标准的 MHA 中,每个注意力头都有独立的 Q、K、V。如果有 $h$ 个头,那就有 $h$ 组 K 和 $h$ 组 V 需要缓存。

  • $h$ 个 Q 头,$h$ 个 K 头,$h$ 个 V 头
  • 每组 Q/K/V 独立,互不共享

9.2 MQA (Multi-Query Attention)

Shazeer 在 2019 年提出:所有 Q 头共享同一组 K 和 V。也就是说,只有 1 个 K 头和 1 个 V 头,但仍然有 $h$ 个 Q 头。

  • $h$ 个 Q 头,1 个 K 头,1 个 V 头
  • 所有 Q 头共享同一份 K 和 V

好处是 KV Cache 缩小为原来的 $1/h$,推理速度大幅提升。代价是模型表达能力下降——所有头被迫从同一组 KV 中提取信息,多样性受限。

9.3 GQA (Grouped-Query Attention)

GQA 是 MHA 和 MQA 的折中方案(Ainslie et al., 2023)。将 $h$ 个 Q 头分成 $g$ 个组(每组 $h/g$ 个 Q 头),每组共享一组 K 和 V。

  • $h$ 个 Q 头,$g$ 个 K 头,$g$ 个 V 头
  • 每 $h/g$ 个 Q 头共享一组 K 和 V
  • 当 $g = h$ 时退化为 MHA,当 $g = 1$ 时退化为 MQA

GQA 在模型质量和推理效率之间取得了更好的平衡。LLaMA-2-70B、Mistral-7B 等主流模型都采用了 GQA。

9.4 对比表格

以 $d_{model} = 4096$,$num_heads = 32$,$head_dim = 128$ 为例,假设序列长度 $N = 4096$,FP16 精度:

指标 MHA ($g=32$) GQA ($g=8$) MQA ($g=1$)
Q 头数 32 32 32
KV 头数 32 8 1
$W_K$ 参数量 $32 \times 128 \times 4096 = 16\text{M}$ $8 \times 128 \times 4096 = 4\text{M}$ $1 \times 128 \times 4096 = 0.5\text{M}$
$W_V$ 参数量 $16\text{M}$ $4\text{M}$ $0.5\text{M}$
单 token KV Cache $2 \times 32 \times 128 \times 2\text{B} = 16\text{KB}$ $2 \times 8 \times 128 \times 2\text{B} = 4\text{KB}$ $2 \times 1 \times 128 \times 2\text{B} = 0.5\text{KB}$
4096 token KV Cache / 层 $64\text{MB}$ $16\text{MB}$ $2\text{MB}$
模型质量 最好 接近 MHA 有下降

可以看到,从 MHA 到 GQA($g=8$),KV Cache 缩小为 1/4,参数量减少,但模型质量几乎不受影响。这就是 GQA 成为主流选择的原因。

9.5 MLA (Multi-Latent Attention)

DeepSeek-V2 提出了 MLA(Multi-Latent Attention),代表了另一种压缩 KV Cache 的思路。

MLA 的核心思想是:不直接缓存完整的 K 和 V 向量,而是将它们压缩到一个低维潜在空间(latent space),只缓存这个低维表示。推理时再将低维表示解压回完整的 K 和 V。

  • 标准 MHA:$X \to W_K \to K$(缓存 $K$,维度 = num_kv_heads $\times$ head_dim)
  • MLA:$X \to W_{DKV} \to c$(缓存 $c$,维度远小于 $K$ 的维度),推理时 $c \to W_{UK} \to K$ 按需解压

MLA 通过低秩压缩将 KV Cache 大幅缩小(DeepSeek-V2 报告了约 93.3% 的压缩率),同时通过精心设计的上投影矩阵保持了模型质量。它与 GQA 的思路不同——GQA 是减少 KV 头的数量,MLA 是降低每个表示的维度——但目标一致:让推理时的 KV Cache 尽可能小。


10. Causal Mask 的作用与实现

10.1 为什么需要遮蔽

在自回归语言模型(如 GPT 系列、LLaMA 系列)中,模型的训练目标是”根据前文预测下一个 token”。这要求在计算 Attention 时,每个位置的 token 只能看到自己和它之前的 token,不能”偷看”未来的信息。否则,模型在训练时就能看到答案,学不到任何有意义的预测能力。

10.2 实现方式

Causal Mask(因果遮蔽)是一个上三角矩阵,覆盖在注意力分数矩阵上,将未来位置的分数设为负无穷:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 构造因果遮蔽矩阵
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""
生成一个下三角矩阵(含对角线),上三角部分为 False。
为 True 的位置保留,为 False 的位置在 Attention 分数中被设为 -inf。
"""
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
return mask

# 使用示例:seq_len = 4
# mask:
# [[True, False, False, False],
# [True, True, False, False],
# [True, True, True, False],
# [True, True, True, True ]]
#
# 被遮蔽后的 Attention 分数矩阵(上三角为 -inf):
# [[s00, -inf, -inf, -inf],
# [s10, s11, -inf, -inf],
# [s20, s21, s22, -inf],
# [s30, s31, s32, s33 ]]
#
# Softmax 之后,-inf 位置的权重变为 0,未来信息被完全屏蔽

Causal Mask 的另一个工程意义在于:它使得 Attention 矩阵只有下三角部分有效,理论上可以跳过上三角部分的计算。FlashAttention-2 正是利用了这一点——在处理上三角区域的 tile 时直接跳过,减少了接近一半的计算量。


11. Self-Attention 的计算瓶颈分析

理解 Self-Attention 的性能瓶颈,需要区分两种不同的场景。

11.1 计算量分析

Self-Attention 中有四次主要的矩阵乘法:

运算 形状 FLOPs
$Q = X \cdot W_Q$ $(N, d) \times (d, d)$ $2Nd^2$
$K = X \cdot W_K$ $(N, d) \times (d, d)$ $2Nd^2$
$V = X \cdot W_V$ $(N, d) \times (d, d)$ $2Nd^2$
$S = Q \cdot K^T$ $(N, d) \times (d, N)$ $2N^2d$
$O = A \cdot V$ $(N, N) \times (N, d)$ $2N^2d$
$Final = O \cdot W_O$ $(N, d) \times (d, d)$ $2Nd^2$

总计:$8Nd^2 + 4N^2d$。

当 $N$ 较小(比如 $N < d$)时,$8Nd^2$ 项主导——瓶颈在 QKV 投影的 GEMM 运算。当 $N$ 较大(比如 $N > d$)时,$4N^2d$ 项主导——瓶颈在 Attention 矩阵的计算。

11.2 Compute Bound vs Memory Bound

GPU 的性能受两个指标限制:

  • 算力(Compute):每秒能做多少次浮点运算(FLOPS)
  • 带宽(Memory Bandwidth):每秒能从 HBM 读写多少数据(GB/s)

一个运算是 Compute Bound 还是 Memory Bound,取决于它的算术强度(Arithmetic Intensity)= FLOPs / Bytes。

  • 如果算术强度 > GPU 的算力/带宽比值,运算是 Memory Bound——GPU 算力用不满,大部分时间在等数据搬运
  • 如果算术强度 < GPU 的算力/带宽比值,运算是 Compute Bound——GPU 带宽够用,大部分时间在做计算

以 A100 为例:FP16 算力约 312 TFLOPS,HBM 带宽约 2 TB/s,算力/带宽比值约 156 FLOP/Byte。

Prefill 阶段(处理完整 prompt):

QKV 投影是大矩阵乘法,batch 维度(seq_len)很大,算术强度高,通常是 Compute Bound。Attention 的 $Q \cdot K^T$ 也是大矩阵乘,同样是 Compute Bound。

Decode 阶段(逐 token 生成):

每步只有 1 个 token,QKV 投影退化为矩阵-向量乘法(GEMV),算术强度极低,是 Memory Bound。Attention 退化为一个向量和整个 KV Cache 的运算,同样 Memory Bound。大部分时间花在从 HBM 搬运模型权重和 KV Cache 上。

11.3 $O(N^2)$ 的显存瓶颈

标准 Attention 实现需要显式地计算并存储完整的 $N \times N$ 注意力矩阵。以 $seq_len = 128\text{K}$、$num_heads = 32$ 为例:

$$
32 \times 128\text{K} \times 128\text{K} \times 2\text{B} = 32 \times 16{,}384\text{M} \times 2 = 1{,}048{,}576\text{ MB} = 1\text{ TB}
$$

这显然远超任何单张 GPU 的显存容量。即使序列长度”只有”8K,注意力矩阵也需要 $32 \times 8\text{K} \times 8\text{K} \times 2\text{B} = 4\text{ GB}$,已经是一个不可忽视的开销。

这个 $O(N^2)$ 的显存瓶颈,正是 FlashAttention 要解决的核心问题。


12. FlashAttention 的核心思想

FlashAttention(Dao et al., 2022)是过去几年 Attention 优化领域最具影响力的工作。它不改变 Attention 的计算结果(是精确计算,不是近似),但通过重新编排计算顺序,大幅减少了对 HBM 的访问量。

12.1 GPU 存储层次回顾

要理解 FlashAttention,需要先了解 GPU 的存储层次:

  • HBM(High Bandwidth Memory):GPU 的主显存,容量大(如 A100 的 80 GB)但访问速度相对慢(2 TB/s)
  • SRAM(片上缓存):每个 SM(Streaming Multiprocessor)上的共享内存和寄存器,容量很小(如 A100 每个 SM 约 192 KB 共享内存,全部 SM 合计约 20 MB)但访问速度极快(约 19 TB/s)

标准 Attention 的流程是:在 HBM 中完成 $Q \cdot K^T$,把完整的 $N \times N$ 注意力矩阵写回 HBM;然后从 HBM 读取注意力矩阵做 Softmax,结果写回 HBM;最后从 HBM 读取 Softmax 结果和 $V$ 做矩阵乘法。每一步都涉及对 $N \times N$ 大矩阵的 HBM 读写——这是巨大的带宽浪费。

12.2 Tiling:分块计算

FlashAttention 的第一个关键技术是 Tiling(分块):不一次性计算完整的 $N \times N$ 注意力矩阵,而是将 Q、K、V 分成若干小块(tile),每次只加载一小块 Q 和一小块 K、V 到 SRAM 中,在 SRAM 内完成该块的 Attention 计算,然后将结果写回 HBM。

1
2
3
4
5
6
7
8
9
10
11
标准实现:
Q, K 全部加载到 HBM → 计算完整 N x N 矩阵 → 写回 HBM → 读取做 Softmax → 写回 HBM → ...

FlashAttention:
将 Q 分成 T_r 块,K/V 分成 T_c 块
For 每一块 Q_i:
For 每一块 K_j, V_j:
从 HBM 加载 Q_i, K_j, V_j 到 SRAM(每块很小,能装下)
在 SRAM 内计算 Q_i @ K_j^T → 局部 Softmax → 乘以 V_j
累积到该块 Q_i 对应的输出中
将 Q_i 块的最终输出写回 HBM

每次只处理一个 tile,注意力矩阵的这一小块始终驻留在 SRAM 中,永远不需要把完整的 $N \times N$ 矩阵写入 HBM。

12.3 Online Softmax 的关键作用

Tiling 带来一个棘手的问题:Softmax 需要对每一行的所有元素做归一化(需要全局最大值和全局求和),但分块计算时每次只看到一行的一部分——怎么在只看到局部数据的情况下计算全局 Softmax?

这就是 Online Softmax 派上用场的地方。回忆第 7.3 节的 Online Softmax 递推公式:当新数据块到来时,可以动态更新全局最大值和归一化分母,并修正之前块的计算结果。

具体来说,处理 $Q_i$ 对 $K_1$ 块的 Attention 后,得到一个局部输出 $O_1$ 和对应的 Softmax 统计量(局部最大值 $m_1$ 和局部分母 $l_1$)。当继续处理 $Q_i$ 对 $K_2$ 块时,会得到新的局部统计量 $m_2$ 和 $l_2$。通过 Online Softmax 的修正:

$$
\begin{aligned}
m_{\text{new}} &= \max(m_1, m_2) \
l_{\text{new}} &= l_1 \cdot e^{m_1 - m_{\text{new}}} + l_2 \cdot e^{m_2 - m_{\text{new}}} \
O_{\text{new}} &= \frac{O_1 \cdot l_1 \cdot e^{m_1 - m_{\text{new}}} + O_2^{\text{local}} \cdot l_2 \cdot e^{m_2 - m_{\text{new}}}}{l_{\text{new}}}
\end{aligned}
$$

不断累积,直到所有 $K$ 块处理完毕,最终结果与标准 Attention 的全局 Softmax 在数学上完全一致。

12.4 为什么能减少 HBM 访问

标准 Attention 的 HBM 访问量:

  • 写入 $S$ ($N \times N$):$O(N^2)$ 次写
  • 读取 $S$ 做 Softmax:$O(N^2)$ 次读
  • 写入 $P = \text{Softmax}(S)$:$O(N^2)$ 次写
  • 读取 $P$ 做 $P \cdot V$:$O(N^2)$ 次读
  • **总 HBM 访问量:$O(N^2)$**(加上 Q, K, V 本身的 $O(Nd)$ 读取)

FlashAttention 的 HBM 访问量:

  • 读取 Q, K, V:$O(Nd)$
  • 写入最终输出 O:$O(Nd)$
  • 中间的注意力矩阵始终在 SRAM 中,不写入 HBM
  • 总 HBM 访问量:$O(Nd)$

从 $O(N^2)$ 降到 $O(Nd)$——当 $N$ 远大于 $d$ 时(长序列场景),这是一个数量级的改进。注意,计算量(FLOPs)没有变——仍然是 $O(N^2d)$——改变的只是 IO 模式。FlashAttention 的加速本质上来自于减少了对慢速 HBM 的访问,让计算尽可能在快速 SRAM 中完成

12.5 FlashAttention-2 和 FlashAttention-3

FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上进一步优化了并行策略和 warp 调度:

  • 调整了内外循环的顺序(外循环遍历 Q 块,内循环遍历 K/V 块),减少了共享内存的读写
  • 更好的 warp 级工作分配,提升了 GPU 的占用率
  • 利用 Causal Mask 跳过全零的 tile,进一步减少计算量

FlashAttention-3(Dao et al., 2024)则针对 Hopper 架构(H100)进行了优化,利用了 TMA(Tensor Memory Accelerator)和 wgmma(Warp Group MMA)指令,进一步提升了硬件利用率。

目前 FlashAttention 已经是所有主流推理和训练框架的标配——PyTorch 2.0+ 内置了 torch.nn.functional.scaled_dot_product_attention,底层默认调用 FlashAttention kernel。


🎯 自我检验清单

完成本文学习后,用以下问题检验自己的理解深度:

  • 能从 Bahdanau Attention 到 Self-Attention 梳理出 Attention 机制的演进脉络,说清每一步解决了什么问题
  • 能默写 Scaled Dot-Product Attention 的完整公式,并解释 Q、K、V 三元组各自的作用
  • 能从零推导 “除以 $\sqrt{d_k}$” 的数学原因:假设 $q_i, k_i$ 独立标准正态,推出点积方差为 $d_k$
  • 能手写 Single-Head 和 Multi-Head Self-Attention 的 PyTorch 实现,能说清每一步的张量维度变化
  • 能解释 Softmax 为什么要减最大值,以及 Online Softmax 的核心递推思想
  • 给定 $seq_len$、$d_{model}$、$num_heads$ 的具体数值,能完整跟踪每一步的张量形状
  • 能画出 MHA / MQA / GQA 的结构差异,并估算各自的 KV Cache 大小
  • 能说清 Causal Mask 的作用和实现方式(上三角设为负无穷,Softmax 后归零)
  • 能区分 Compute Bound 和 Memory Bound,并说清 Prefill 和 Decode 阶段各自的瓶颈类型
  • 能说清 FlashAttention 的核心思想:Tiling + Online Softmax 如何将 HBM 访问从 $O(N^2)$ 降到 $O(Nd)$

📚 参考资料

论文

教程与博客