2.7 Transformer Decoder Block完整解析
大语言模型的核心计算单元是 Transformer Decoder Block。无论你在做 CUDA 算子优化、分布式训练还是推理部署,最终操作的对象都是这个 Block 里面的矩阵乘法、归一化和注意力计算。本文将这个 Block 彻底拆开,从架构选型的历史原因讲起,逐步深入到因果掩码的实现、完整的 PyTorch 代码、参数量与计算量的手算方法,最后落地到显存规划的工程实践。目标是读完之后,你能拿着纸笔算清楚任意一个开源模型”能不能装进某张卡”。
📑 目录
- 1. 架构选型:为什么 Decoder-only 成为主流
- 2. Decoder Block 完整数据流
- 3. Causal Mask:因果掩码详解
- 4. PyTorch 实现:从零搭建 Decoder Block
- 5. 主流模型维度配置对比
- 6. 参数量手算教学
- 7. 计算量(FLOPs)估算
- 8. 显存规划详解
- 9. 主流开源模型架构对比
- 自我检验清单
- 参考资料
1. 架构选型:为什么 Decoder-only 成为主流
1.1 三种 Transformer 架构回顾
2017 年的原始论文 “Attention Is All You Need” 提出的是一个 Encoder-Decoder 架构:Encoder 负责理解输入,Decoder 负责生成输出。此后演化出三条路线:
| 架构 | 代表模型 | 核心特点 | 典型任务 |
|---|---|---|---|
| Encoder-only | BERT, RoBERTa | 双向注意力,看到完整上下文 | 分类、NER、句子相似度 |
| Encoder-Decoder | T5, BART, mBART | Encoder 双向理解,Decoder 自回归生成 | 翻译、摘要、问答 |
| Decoder-only | GPT 系列, LLaMA, Mistral, Qwen | 单向因果注意力,自回归生成 | 通用文本生成、对话、推理 |
打个比方:Encoder-only 像一个阅读理解专家,擅长”读懂”但不会”写作”;Encoder-Decoder 像一个翻译官,需要先完整理解原文再逐句翻译;Decoder-only 则像一个即兴演讲者,边想边说,每句话只基于前面已经说过的内容。
1.2 Decoder-only 胜出的原因
进入大模型时代后,Decoder-only 几乎一统天下。这并非巧合,背后有多重原因:
统一的训练范式。Decoder-only 的训练目标极其简单——预测下一个 token。无论输入是什么语言、什么任务,训练信号都是统一的。相比之下,Encoder-Decoder 需要设计”输入-输出”对,数据构造更复杂。当你有数万亿 token 的无标注文本时,”预测下一个词”是最自然、最高效的利用方式。
规模扩展(Scaling)更简单。Decoder-only 架构只有一种 Block 不断堆叠,想扩大模型只需增加层数或隐藏维度。Encoder-Decoder 架构需要同时扩展两个组件,还得平衡二者的比例,调参空间更大。OpenAI 的 Scaling Laws 研究表明,在固定计算预算下,Decoder-only 架构的参数效率与 Encoder-Decoder 基本相当,但工程复杂度低得多。
推理效率的优势。Decoder-only 架构在推理时只需维护一套 KV Cache,而 Encoder-Decoder 需要维护 Encoder 侧的输出和 Decoder 侧的 KV Cache 两套数据。对于长上下文的对话场景,简单的 KV Cache 管理意味着更高的系统吞吐量。
涌现能力的经验观察。实践中发现,当模型参数量超过一定阈值后,Decoder-only 架构在 few-shot 和 zero-shot 场景下展现出显著的涌现能力(如 chain-of-thought 推理)。这些能力在 Encoder-only 和 Encoder-Decoder 架构上不那么明显。
总结一句话:Decoder-only 在大规模场景下,以最简洁的架构获得了最好的通用能力,同时对工程系统最友好。 从 GPT-3 到 LLaMA,再到 Mistral、Qwen、DeepSeek,全部采用这一路线。
2. Decoder Block 完整数据流
一个 Pre-Norm Decoder-only Block 内部的数据流转可以用下面的图来表示。相比入门篇的简化版本,这里标注了每一步的具体操作和维度变化:
1 | Input x (seq_len, d_model) |
几个值得注意的细节:
RMSNorm 而非 LayerNorm。LLaMA 系列及后续大多数模型使用 RMSNorm(Root Mean Square Normalization),省去了减均值的步骤,只保留除以均方根的操作。计算更简单,效果相当。
RoPE 只作用于 Q 和 K。旋转位置编码的作用是让 Q 和 K 的内积包含相对位置信息,V 不需要旋转,因为 V 承载的是”内容信息”而非”位置匹配信号”。
Pre-Norm 的残差路径是干净的。注意图中两条残差路径都是从 Add 节点直接拉过来的,中间没有经过任何变换。这保证了梯度可以无损地沿残差路径回传,是深层模型训练稳定的关键。
3. Causal Mask:因果掩码详解
3.1 为什么需要因果掩码
Decoder-only 模型的训练目标是”给定前面的 token,预测下一个 token”。这意味着在计算第 i 个 token 的 Attention 时,它只能看到位置 0 到 i 的信息,不能”偷看”位置 i+1 及之后的 token——否则预测任务就变成了”开卷考试”,模型学不到任何东西。
这个约束在自然语言生成中是合理的:你在说第五个字的时候,确实不知道第六个字会是什么。模型必须遵守同样的因果顺序。
在训练时,为了效率,我们把整个序列一次性送入模型并行计算(而非逐 token 送入)。但并行计算意味着所有 token 的 Attention 是同时算的,如果不加限制,每个 token 都会”看到”整个序列。因果掩码(Causal Mask)就是用来在并行计算的同时强制执行”只能看过去”的约束。
3.2 掩码矩阵的可视化
假设序列长度为 5,Attention 分数矩阵 S 的形状是 (5, 5)。S[i][j] 表示第 i 个 token 对第 j 个 token 的原始注意力分数。因果掩码要求:当 j > i(即 j 在 i 后面)时,S[i][j] 必须被屏蔽。
掩码矩阵 M 长这样(0 表示保留,-inf 表示屏蔽):
1 | Token: t0 t1 t2 t3 t4 |
将 M 加到 S 上之后,被屏蔽的位置变成负无穷。接下来做 softmax 时,$e^{-\infty} = 0$,这些位置的注意力权重就变成了零。
用另一种直观的方式表示——哪些位置能被”看到”(用 1 表示可见,0 表示不可见):
1 | Token: t0 t1 t2 t3 t4 |
这是一个下三角矩阵,因此因果掩码也被称为”下三角掩码”或”上三角掩码”(取决于你说的是保留区域还是屏蔽区域)。
3.3 实现方式
在代码层面,因果掩码的实现非常简洁:
1 | import torch |
使用时直接加到 Attention 分数上:
1 | scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) |
在实际的高性能实现中(如 FlashAttention),因果掩码不是显式地构造一个 N x N 矩阵然后相加,而是在 tiled 计算过程中通过索引判断来跳过不需要计算的上三角区域,既节省显存又减少无效计算。
4. PyTorch 实现:从零搭建 Decoder Block
下面是一个完整的、可运行的 Transformer Decoder Block 实现,包含 Pre-Norm(RMSNorm)、Masked Multi-Head Attention(带 RoPE)、SwiGLU FFN 和残差连接。
1 | import torch |
代码说明几个关键点:
- RMSNorm 只有一个可学习参数
weight(即 $\gamma$),没有 bias。它对每个 token 的特征向量计算均方根,然后归一化。 - RoPE 是预计算的,不引入额外可学习参数。只对 Q 和 K 旋转,V 保持不变。
- SwiGLU FFN 有三个权重矩阵,其中 $W_{gate}$ 的输出经过 SiLU 激活后与 $W_{up}$ 的输出逐元素相乘,形成门控机制,最后由 $W_{down}$ 降维。
- 残差连接在 Block 的
forward中用简单的加法实现:h = x + self.attn(self.norm1(x), ...)。
5. 主流模型维度配置对比
不同规模的模型本质上是同一套 Decoder Block 结构,只是维度配置不同。下表汇总了几个标志性模型的关键参数:
| 模型 | $d_{model}$ | $h$ | $h_{kv}$ | $d_k$ | $L$ | $d_{ff}$ | $V$ | max_seq_len |
|---|---|---|---|---|---|---|---|---|
| LLaMA-2-7B | 4096 | 32 | 32 (MHA) | 128 | 32 | 11008 | 32000 | 4096 |
| LLaMA-2-13B | 5120 | 40 | 40 (MHA) | 128 | 40 | 13824 | 32000 | 4096 |
| LLaMA-2-70B | 8192 | 64 | 8 (GQA) | 128 | 80 | 28672 | 32000 | 4096 |
| GPT-3 175B | 12288 | 96 | 96 (MHA) | 128 | 96 | 49152 | 50257 | 2048 |
| Mistral-7B | 4096 | 32 | 8 (GQA) | 128 | 32 | 14336 | 32000 | 32768 |
几个值得注意的规律:
$d_k$(每头维度)基本恒定为 128。无论模型多大,每个注意力头处理的维度都是 128。模型变大时增加的是头的数量($h$)和层数($L$),而非单头维度。这是因为 128 维已经足以让每个头捕捉一种有意义的注意力模式。
$d_{ff}$ 与 $d_{model}$ 的比例。标准 Transformer 中 $d_{ff} = 4 \times d_{model}$。但使用 SwiGLU 后,为了保持总参数量不变(三个矩阵 vs 两个矩阵),通常取 $d_{ff} = (8/3) \times d_{model}$,再向上取整到某个方便的数。比如 LLaMA-2-7B 的 $(8/3) \times 4096 = 10922.67$,取整到 11008(256 的倍数,有利于 GPU 计算对齐)。Mistral-7B 的 $d_{ff} = 14336$ 略大,因为其设计选择了更大的 FFN。
GQA 的引入。LLaMA-2-70B 和 Mistral-7B 使用了 Grouped-Query Attention:KV 头数少于 Q 头数。LLaMA-2-70B 用 8 个 KV 头服务 64 个 Q 头(每组 8 个 Q 头共享 1 组 KV),这将 KV Cache 减少到 MHA 的 1/8,大幅降低推理时的显存消耗。
6. 参数量手算教学
能手算模型参数量是 AI Infra 工程师的基本功。知道参数量,才能估算显存需求、通信开销和训练成本。
6.1 通用公式推导
一个标准的 Decoder-only 模型由以下部分组成:
(1)Token Embedding 层
将 token ID 映射为向量:
$$P_{embed} = V \times d$$
其中 $V$ 是词表大小,$d$ 是 $d_{model}$。
(2)单个 Decoder Block
对于使用 MHA(所有头独立的 KV)和 SwiGLU FFN 的 Block:
Attention 部分四个投影矩阵(不含 bias):
$$P_{attn} = 4 \times d^2$$
为什么是 $4d^2$?$W_Q$、$W_K$、$W_V$、$W_O$ 每个都是 $(d, d)$ 的矩阵,每个有 $d^2$ 个参数。
如果使用 GQA(KV 头数为 $h_{kv}$,Q 头数为 $h_q$,每头维度 $d_k$):
$$P_{attn} = d \times (h_q \cdot d_k) + 2 \times d \times (h_{kv} \cdot d_k) + (h_q \cdot d_k) \times d$$
$$= d^2 + 2 \times d \times h_{kv} \times d_k + d^2$$
SwiGLU FFN 部分三个矩阵(不含 bias):
$$P_{ffn} = 3 \times d \times d_{ff}$$
其中 $d_{ff}$ 是 FFN 中间维度。
RMSNorm 部分(两个,每个有 $d$ 个参数):
$$P_{norm} = 2 \times d$$
单个 Block 合计:
$$P_{block} = 4d^2 + 3 \times d \times d_{ff} + 2d \quad (\text{MHA 情况})$$
(3)最终的 RMSNorm + 输出头(LM Head)
$$P_{final} = d + V \times d$$
第一项是最终 RMSNorm 的参数,第二项是 LM Head 的参数(将 $d_{model}$ 映射到 $V$)。
(4)模型总参数量
$$P_{total} = V \times d + L \times P_{block} + d + V \times d$$
其中 $L$ 是层数。如果 Embedding 和 LM Head 共享权重(Weight Tying),则减去一个 $V \times d$。
6.2 详细计算:以 LLaMA-2-7B 为例
配置回顾:$d = 4096$, $h = 32$, $d_k = 128$, $d_{ff} = 11008$, $L = 32$, $V = 32000$
Attention 参数量(单 Block):
| 矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_Q$ | 4096 x 4096 | 16,777,216 |
| $W_K$ | 4096 x 4096 | 16,777,216 |
| $W_V$ | 4096 x 4096 | 16,777,216 |
| $W_O$ | 4096 x 4096 | 16,777,216 |
| Attention 小计 | 67,108,864 (67.1M) |
验证:$4 \times 4096^2 = 4 \times 16{,}777{,}216 = 67{,}108{,}864$
FFN 参数量(单 Block):
| 矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_{gate}$ | 4096 x 11008 | 45,088,768 |
| $W_{up}$ | 4096 x 11008 | 45,088,768 |
| $W_{down}$ | 11008 x 4096 | 45,088,768 |
| FFN 小计 | 135,266,304 (135.3M) |
验证:$3 \times 4096 \times 11008 = 3 \times 45{,}088{,}768 = 135{,}266{,}304$
RMSNorm 参数量(单 Block):
$2 \times 4096 = 8{,}192$
单 Block 合计:
$67{,}108{,}864 + 135{,}266{,}304 + 8{,}192$ = 202,383,360 (~202M)
其中 FFN 占比 = 135.3M / 202.4M = **66.8%**,Attention 占比 = 67.1M / 202.4M = 33.2%
整个模型:
| 组件 | 计算 | 参数量 |
|---|---|---|
| Token Embedding | $32000 \times 4096$ | 131,072,000 |
| 32 层 Block | $32 \times 202{,}383{,}360$ | 6,476,267,520 |
| 最终 RMSNorm | 4096 | 4,096 |
| LM Head | $4096 \times 32000$ | 131,072,000 |
| 总计 | 6,738,415,616 (~6.74B) |
如果 Embedding 和 LM Head 共享权重,则减去 131M,约 6.61B。官方标注 “7B” 是近似值。
6.3 练习:LLaMA-2-13B 参数量
配置:$d = 5120$, $h = 40$, $d_k = 128$, $d_{ff} = 13824$, $L = 40$, $V = 32000$
读者可以先自己算,再对照下面的答案:
单 Block:
- Attention: $4 \times 5120^2 = 4 \times 26{,}214{,}400 = 104{,}857{,}600$ (104.9M)
- FFN: $3 \times 5120 \times 13824 = 3 \times 70{,}778{,}880 = 212{,}336{,}640$ (212.3M)
- RMSNorm: $2 \times 5120 = 10{,}240$
- 单 Block 合计: $317{,}204{,}480$ (~317M)
整个模型:
- Embedding: $32000 \times 5120 = 163{,}840{,}000$
- 40 层 Block: $40 \times 317{,}204{,}480 = 12{,}688{,}179{,}200$
- 最终 RMSNorm: $5{,}120$
- LM Head: $5120 \times 32000 = 163{,}840{,}000$
- 总计: $13{,}015{,}864{,}320$ (~13.0B)
与官方标注的 13B 吻合。
6.4 Weight Tying 技术
不少模型会让 Token Embedding 和 LM Head 共享同一个权重矩阵。这个技巧叫做 Weight Tying,最早在 2017 年由 Press 和 Wolf 提出。
直觉上可以这样理解:Embedding 层的工作是”将 token ID 映射为语义向量”(从离散空间到连续空间),LM Head 的工作是”将语义向量映射回 token 概率”(从连续空间到离散空间)。这两个操作是互逆的,使用相同的权重矩阵(一个用正矩阵,一个用其转置)是合理的。
Weight Tying 的好处:
- 节省参数:对于 $V=32000$, $d=4096$ 的模型,节省 131M 参数,约占 7B 模型的 2%
- 正则化效果:共享权重相当于一种隐式的约束,防止 Embedding 空间和输出空间”漂移”
- 减少显存:少存储一个 $(V, d)$ 矩阵
LLaMA-2 系列使用了 Weight Tying,而 GPT-3 没有使用。是否使用取决于词表大小与模型大小的比例——当词表相对于模型较小时(如 32000 vs 7B),共享带来的参数节省比例很小,但正则化效果仍然有意义。
7. 计算量(FLOPs)估算
知道参数量可以算显存,而知道计算量(FLOPs)则可以估算训练时间和硬件利用率。
7.1 矩阵乘法的 FLOPs
计算量估算的基础是矩阵乘法。一个 $(M, K) \times (K, N)$ 的矩阵乘法:
- 结果矩阵有 $M \times N$ 个元素
- 每个元素需要 $K$ 次乘法和 $K-1$ 次加法
- 总 FLOPs 约为 $2 \times M \times K \times N$(乘法和加法各算一次浮点操作)
7.2 单次前向传播的 FLOPs
对于一个有 $N$ 个参数的 Transformer 模型,处理长度为 $s$ 的序列时,前向传播的计算量近似为:
$$FLOPs_{forward} \approx 2 \times N \times s$$
这就是 “$2N$” 经验法则。这里的直觉是:模型的主要计算都是矩阵乘法,每个参数在前向传播中恰好参与一次矩阵乘法,贡献约 2 FLOPs(一次乘法一次加法),再乘以序列中的 $s$ 个 token。
更精确地说,这个 $2Ns$ 只计算了线性层(GEMM)的 FLOPs,忽略了 Attention 中 $QK^T$ 和 $AV$ 的计算量(这部分与序列长度的平方成正比)。完整的公式是:
$$FLOPs_{forward} \approx 2Ns + 2 \times L \times h \times s^2 \times d_k$$
其中第二项是 Attention 的计算量(每层每个头有两个 $(s, d_k) \times (d_k, s)$ 的矩阵乘法)。当序列长度 $s$ 较短时(比如 2048),第二项远小于第一项,”$2Ns$” 是一个好的近似。当 $s$ 很长(如 128K)时,Attention 的计算量可能接近甚至超过线性层。
7.3 训练 vs 推理的 FLOPs
训练时的 FLOPs
训练包括前向传播和反向传播。经验上,反向传播的计算量约为前向传播的 2 倍(需要计算对权重和对输入的梯度)。因此:
$$FLOPs_{train} \approx 3 \times FLOPs_{forward} = 6Ns \quad (\text{per token})$$
对于整个训练过程,如果训练了 $T$ 个 token:
$$FLOPs_{total} = 6 \times N \times T$$
举个例子:LLaMA-2-7B 用 2 万亿 token 训练:
- FLOPs $= 6 \times 6.7\text{B} \times 2\text{T} = 6 \times 6.7 \times 10^9 \times 2 \times 10^{12} = 8.04 \times 10^{22}$
如果使用 1000 张 A100(BF16 峰值算力 312 TFLOPS 每张,假设 MFU=50%):
- 有效算力 $= 1000 \times 312 \times 10^{12} \times 0.5 = 1.56 \times 10^{17}$ FLOPS
- 训练时间 $= 8.04 \times 10^{22} / 1.56 \times 10^{17} = 515{,}385$ 秒 = 约 6 天
这与实际公开的训练时间量级相符。
推理时的 FLOPs
推理只有前向传播,per token 约 $2N$ FLOPs。但推理的特殊之处在于 Decode 阶段每步只处理 1 个 token,矩阵乘法退化为矩阵-向量乘(GEMV),GPU 的算力远远用不满,瓶颈变成了显存带宽而非计算能力。所以推理优化更关注显存带宽(Memory Bound)而非峰值算力。
7.4 估算实例
以 LLaMA-2-7B 推理为例,在 A100-80GB 上:
- 单 token Decode FLOPs $= 2 \times 6.7\text{B} = 13.4$ GFLOPs
- A100 FP16 峰值算力 $= 312$ TFLOPS
- 理论上 FLOPs 只需 $13.4\text{G} / 312\text{T} = 0.043\text{ms}$
但实际一个 token 的 Decode 时间约为 10-20ms。为什么差了几百倍?因为 Decode 是 Memory Bound:需要从 HBM 搬运全部模型权重(13.4 GB FP16),A100 的 HBM 带宽是 2 TB/s,光搬权重就需要 13.4 GB / 2 TB/s = 6.7ms,再加上 KV Cache 的搬运和其他开销,10-20ms 就合理了。
8. 显存规划详解
显存规划是 AI Infra 工程中最常见的实操问题:给定一个模型和一张(或多张)GPU,判断能不能放下,如果放不下该怎么办。
8.1 模型权重显存
模型权重的显存取决于参数量和存储精度:
| 精度格式 | 每个参数占用 | 7B 模型权重大小 | 13B 模型权重大小 | 70B 模型权重大小 |
|---|---|---|---|---|
| FP32 | 4 Bytes | 26.8 GB | 52.0 GB | 280.0 GB |
| FP16 / BF16 | 2 Bytes | 13.4 GB | 26.0 GB | 140.0 GB |
| INT8 | 1 Byte | 6.7 GB | 13.0 GB | 70.0 GB |
| INT4 | 0.5 Bytes | 3.35 GB | 6.5 GB | 35.0 GB |
计算公式:显存 = 参数量 x 每参数字节数
FP16 和 BF16 都是 16 位浮点数,占用相同的显存。BF16 的指数位更多(8 位 vs FP16 的 5 位),数值范围更大,训练时更不容易溢出,是目前训练的主流选择。
8.2 训练态显存:四大组成部分
训练一个模型需要的显存远大于存储权重本身。以混合精度训练(BF16 权重 + FP32 优化器)为例:
(1)模型权重:2 Bytes/param
训练时模型以 BF16 存储,即 $2N$ Bytes($N$ 为参数量)。
(2)梯度:2 Bytes/param
梯度与权重形状相同,BF16 存储,$2N$ Bytes。
(3)优化器状态:视优化器而定
这是显存消耗的大头。以最常用的 Adam/AdamW 优化器为例,它需要维护三样东西:
- FP32 参数副本(Master Weights):$4N$ Bytes。为什么需要 FP32 副本?BF16 只有约 3-4 位有效数字,在更新权重时,如果学习率乘以梯度的值很小(比如 1e-5),BF16 的精度不够表示这个微小的增量,更新就会被”四舍五入”掉。FP32 有约 7 位有效数字,能捕捉这些微小更新。
- 一阶动量(First Moment, m):$4N$ Bytes。Adam 维护梯度的指数移动平均,用于估计梯度的均值。
- 二阶动量(Second Moment, v):$4N$ Bytes。Adam 维护梯度平方的指数移动平均,用于估计梯度的方差,实现自适应学习率。
优化器状态合计:$4N + 4N + 4N$ = $12N$ Bytes。
所以人们说”Adam 需要 4x 参数量的显存”,指的就是 $12N / (2N + 2N)$ = 大约 3-4 倍额外显存:优化器状态 $12N$ 本身就是 BF16 权重 $2N$ 的 6 倍。
(4)Activation Memory(激活值显存)
前向传播中间的激活值需要保存下来供反向传播使用。激活值显存与 batch_size 和序列长度成正比,粗略估算公式为:
$$M_{act} \approx s \times b \times d \times L \times k$$
其中 $s$ 是序列长度,$b$ 是 batch_size,$d$ 是 $d_{model}$,$L$ 是层数,$k$ 是一个常数(约 10-14,取决于是否使用 Activation Checkpointing)。
对于 LLaMA-2-7B,seq_len=2048, batch_size=4, 不使用 Activation Checkpointing:
- 粗估 $M_{act} \approx 2048 \times 4 \times 4096 \times 32 \times 12 \times 2$ Bytes (BF16)
- $\approx 2048 \times 4 \times 4096 \times 32 \times 24 = 25.8$ GB
使用 Activation Checkpointing(只保存每层输入,反向时重新计算中间值)可以将激活值显存减少到约 1/3 到 1/5,代价是增加约 33% 的计算量。
训练态显存汇总(LLaMA-2-7B,6.7B 参数):
| 组件 | 计算方式 | 显存 |
|---|---|---|
| BF16 模型权重 | $6.7\text{B} \times 2$ | 13.4 GB |
| BF16 梯度 | $6.7\text{B} \times 2$ | 13.4 GB |
| FP32 Master Weights | $6.7\text{B} \times 4$ | 26.8 GB |
| Adam 一阶动量 (FP32) | $6.7\text{B} \times 4$ | 26.8 GB |
| Adam 二阶动量 (FP32) | $6.7\text{B} \times 4$ | 26.8 GB |
| 静态合计 | 107.2 GB | |
| Activation (估算) | batch=4, seq=2048 | ~25.8 GB |
| 总计 | ~133 GB |
133 GB 已经超过了一张 A100-80GB 的显存。这就是为什么训练 7B 模型看起来规模不大,但实际上需要分布式训练(ZeRO、张量并行等)才能跑起来。
8.3 推理态显存:权重 + KV Cache
推理不需要梯度和优化器状态,显存需求简单很多:
(1)模型权重
以 FP16 推理为例:6.7B x 2 = 13.4 GB
(2)KV Cache
每个 token 在每一层需要缓存 K 和 V 各一个向量:
$$M_{kv} = 2 \times L \times h_{kv} \times d_k \times s \times b \times \text{bytes_per_element}$$
以 LLaMA-2-7B(MHA, $h_{kv} = 32$)为例:
- 单个 token:$2 \times 32 \times 32 \times 128 \times 2 = 524{,}288$ Bytes = 512 KB
- seq_len = 4096:$4096 \times 512$ KB = 2 GB
- batch_size = 16:$16 \times 2$ GB = 32 GB
推理态显存汇总:
| 组件 | LLaMA-2-7B (FP16) | LLaMA-2-7B (INT4) |
|---|---|---|
| 模型权重 | 13.4 GB | 3.35 GB |
| KV Cache (seq=4096, batch=16) | 32 GB | 32 GB (KV 通常仍用 FP16) |
| 其他开销 (框架, buffer) | ~1 GB | ~1 GB |
| 总计 | ~46.4 GB | ~36.4 GB |
注意:即使模型权重量化到 INT4,KV Cache 通常仍然使用 FP16,因为量化 KV Cache 对精度影响较大。近年来有 KV Cache 量化的研究(如 KIVI、KVQuant),可以将 KV Cache 压缩到 INT4 甚至 INT2,但需要额外的校准和精度评估。
8.4 完整规划案例
场景:用单张 A100-80GB 部署 LLaMA-2-13B 进行推理,目标 batch_size=8,最大序列长度 4096,能否装下?
第一步,计算模型权重:
- 13B 参数 $\times$ 2 Bytes (FP16) = 26 GB
第二步,计算 KV Cache:
- LLaMA-2-13B 配置:$L=40$, $h_{kv}=40$ (MHA), $d_k=128$
- 单 token KV:$2 \times 40 \times 40 \times 128 \times 2 = 819{,}200$ Bytes = 800 KB
- seq=4096, batch=8:$4096 \times 800\text{KB} \times 8 = 25.6$ GB
第三步,加上框架开销:
- CUDA 上下文 + 框架 buffer + 临时空间 约 2 GB
第四步,汇总:
- 26 + 25.6 + 2 = 53.6 GB
结论:53.6 GB < 80 GB,可以装下,还有约 26 GB 的余量。
但如果想把 batch_size 提高到 16 呢?
- KV Cache 翻倍:$25.6 \times 2 = 51.2$ GB
- 总计:$26 + 51.2 + 2 = 79.2$ GB
非常接近 80 GB 上限,几乎没有余量,实际运行大概率 OOM。此时的选择:
- 使用 INT8 量化模型权重:$13\text{B} \times 1 = 13$ GB,总计 66.2 GB,可行
- 使用 GQA 模型(如 Mistral)减少 KV Cache
- 使用 KV Cache 量化
- 减少最大序列长度
这就是显存规划的实际价值——不是拍脑袋说”应该能跑”,而是精确地算出每一项开销,找到瓶颈,选择合适的优化策略。
9. 主流开源模型架构对比
最后用一张表对比当前主流开源 LLM 的架构选择,帮助建立全局视野:
| 特性 | LLaMA-2 | LLaMA-3 | Mistral-7B | Qwen-2.5 | DeepSeek-V3 |
|---|---|---|---|---|---|
| Attention 类型 | MHA (7B/13B) / GQA (70B) | GQA | GQA | GQA | MLA (Multi-head Latent Attention) |
| KV 头数 (7B 级) | 32 (MHA) | 8 | 8 | 4 | – (MLA 压缩 KV) |
| 位置编码 | RoPE | RoPE | RoPE | RoPE | RoPE |
| 归一化 | RMSNorm (Pre-Norm) | RMSNorm (Pre-Norm) | RMSNorm (Pre-Norm) | RMSNorm (Pre-Norm) | RMSNorm (Pre-Norm) |
| FFN 类型 | SwiGLU | SwiGLU | SwiGLU | SwiGLU | SwiGLU + MoE |
| 是否 MoE | 否 | 否 | Mixtral 版是 | 部分版本是 | 是 (256 专家, Top-8) |
| 词表大小 | 32000 | 128256 | 32000 | 151936 | 129280 |
| 最大上下文 | 4096 | 8192 (可扩展) | 32768 (Sliding Window) | 32768+ | 128K+ |
| 特殊设计 | – | 更大词表、更长上下文 | Sliding Window Attention | 按需扩展 | MLA + MoE 组合 |
几个关键趋势:
GQA 成为标配。从 LLaMA-2 的 MHA 到后续模型几乎全部采用 GQA,核心驱动力是降低推理时的 KV Cache 显存。KV 头数从 32 一路减少到 8 甚至 4,KV Cache 减少了 4-8 倍。
词表持续扩大。LLaMA-2 的 32000 到 LLaMA-3 的 128256,更大的词表意味着更好的多语言支持和更高的 token 效率(同样的文本用更少的 token 表示),但也增加了 Embedding 层的参数量。
MoE 架构兴起。DeepSeek-V3 和 Mixtral 采用混合专家模型,用更大的总参数量但更少的激活参数量(每个 token 只激活少数专家)来提升效果。这给分布式训练和推理带来了新的工程挑战(Expert Parallelism、负载均衡等)。
MLA 的创新。DeepSeek-V3 的 Multi-head Latent Attention 将 KV 投影到低维潜在空间再恢复,在保持模型能力的同时大幅压缩 KV Cache,是对 GQA 思路的进一步发展。
🎯 自我检验清单
学完本文后,检验以下能力:
- 能解释 Decoder-only 架构为什么在大模型时代胜出,至少说出三个理由
- 能在白板上画出完整的 Decoder Block 数据流图,标注 RMSNorm、Masked MHA、SwiGLU FFN、残差连接的位置和顺序
- 能画出因果掩码矩阵(给定序列长度 $N$),解释上三角 $-\infty$ 的作用
- 能手算 LLaMA-2-7B 和 13B 的参数量(误差不超过 5%),说清 FFN 和 Attention 的参数比例
- 能用 “$6Ns$” 公式估算训练 FLOPs,用 “$2Ns$” 估算推理 FLOPs
- 能计算给定模型在 FP16、INT8、INT4 下的权重显存
- 能算出 Adam 优化器需要多少显存,并解释为什么需要 FP32 Master Weights
- 给定一个模型配置和 GPU 型号,能完成完整的显存规划(权重 + KV Cache 或 权重 + 梯度 + 优化器 + Activation),判断是否能装下
📚 参考资料
论文
- Attention Is All You Need (Vaswani et al., 2017): https://arxiv.org/abs/1706.03762 – Transformer 原始论文
- LLaMA 2: Open Foundation and Fine-Tuned Chat Models (Touvron et al., 2023): https://arxiv.org/abs/2307.09288 – LLaMA-2 技术报告
- Mistral 7B (Jiang et al., 2023): https://arxiv.org/abs/2310.06825 – Mistral 架构与 Sliding Window Attention
- GLU Variants Improve Transformer (Shazeer, 2020): https://arxiv.org/abs/2002.05202 – SwiGLU 激活函数
- Using the Output Embedding to Improve Language Models (Press & Wolf, 2017): https://arxiv.org/abs/1608.05859 – Weight Tying 技术
- Scaling Laws for Neural Language Models (Kaplan et al., 2020): https://arxiv.org/abs/2001.08361 – 模型规模与性能的缩放规律
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Rajbhandari et al., 2020): https://arxiv.org/abs/1910.02054 – ZeRO 显存优化
- DeepSeek-V3 Technical Report (DeepSeek-AI, 2024): https://arxiv.org/abs/2412.19437 – MLA + MoE 架构
教程与博客
- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/
- Transformer Math 101 (EleutherAI): https://blog.eleuther.ai/transformer-math/ – 参数量和计算量的详细推导
- LLM Training: FSDP vs DeepSpeed (Hugging Face): https://huggingface.co/docs/transformers/main/en/perf_train_gpu_many