分布式训练入门
当模型参数量超越单卡显存极限时,分布式训练就是必经之路。本文从 Transformer 模型基础讲起,系统覆盖数据并行、模型并行(3D 并行)、ZeRO 显存优化、混合精度训练等核心技术,并提供 PyTorch DDP 和 DeepSpeed 的实战代码,帮助从业者建立分布式训练的完整知识体系。
📑 目录
- 1. 模型基础:为什么需要分布式
- 2. 显存占用分析
- 3. 数据并行
- 4. 模型并行(3D 并行)
- 5. ZeRO 显存优化
- 6. 混合精度训练
- 7. 其他显存优化技术
- 8. 训练框架实战
- 9. 并行策略选型指南
- 10. 自我检验清单
- 参考资料
1. 模型基础:为什么需要分布式
1.1 Transformer 架构速览
几乎所有现代大语言模型都基于 Transformer 架构。理解其核心组件是理解分布式训练切分策略的前提。
一个 Transformer 层(Layer)由两个子模块组成:
1 | 输入 x |
Self-Attention:核心计算是 Q、K、V 三个矩阵的投影和 Attention 分数计算。
1 | Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V |
其中 Q、K、V 由输入 x 通过三个权重矩阵 W_Q、W_K、W_V 投影得到。
FFN(Feed-Forward Network):通常是两层线性变换 + 激活函数。
1 | FFN(x) = W_down * activation(W_up * x) |
W_up 将隐藏维度从 d_model 扩展到 d_ff(通常 4 倍),W_down 再缩回 d_model。
一个完整的 LLM = Embedding 层 + N 个 Transformer 层 + LM Head(输出层)。
1.2 Attention 变种
不同的 Attention 机制对显存和计算量有直接影响:
MHA(Multi-Head Attention)
标准多头注意力。每个 Head 拥有独立的 Q、K、V 投影。
1 | Head 0: Q_0, K_0, V_0 → Attention_0 |
KV Cache 大小 = 2 * n_layers * n_heads * d_head * seq_len * batch_size
MQA(Multi-Query Attention)
所有 Head 共享同一组 K、V,只有 Q 是独立的。KV Cache 减少到 1/n_heads。
GQA(Grouped-Query Attention)
MHA 和 MQA 的折中。将 Head 分成若干组(Group),组内共享 K、V。LLaMA 2 70B、Mistral 等采用。
MLA(Multi-head Latent Attention)
DeepSeek V2 提出。通过低秩压缩将 KV 投影到低维潜空间,进一步减少 KV Cache。
1 | MHA: KV Cache = 2 * n_heads * d_head (每层每 token) |
1.3 FFN 变种:MoE(混合专家模型)
标准 FFN 中所有 token 都经过同一组参数。MoE(Mixture of Experts)将 FFN 替换为多个”专家”网络,每个 token 只激活其中 Top-K 个专家。
1 | 标准 FFN: |
关键特性:
- 总参数量大:例如 8 个专家意味着 FFN 参数量 x8
- 单 token 计算量小:每个 token 只激活 Top-2 个专家,计算量接近普通 FFN
- 对分布式有特殊要求:不同专家可以放在不同 GPU 上(Expert Parallelism)
DeepSeek V3 使用了 256 个细粒度专家 + 共享专家的 MoE 架构,总参数 671B 但每 token 仅激活 37B。
2. 显存占用分析
在讨论分布式策略之前,必须先搞清楚训练时的显存都花在了哪里。
2.1 显存占用的四大部分
以一个参数量为 P(单位:元素数)的模型为例:
| 项目 | FP32 训练 | 混合精度训练 (FP16/BF16) |
|---|---|---|
| 模型参数 | 4P 字节 | 2P (FP16) + 4P (FP32 主副本) |
| 梯度 | 4P 字节 | 2P (FP16) 或 4P (FP32) |
| 优化器状态 (Adam) | 8P 字节 (m + v) | 8P 字节 (FP32 的 m + v) |
| 激活值 | 取决于 batch/seq/层数 | 取决于 batch/seq/层数 |
混合精度下 Adam 训练的显存公式:
1 | 模型参数: 2P (FP16 参数) |
2.2 实际案例估算
| 模型 | 参数量 P | 参数+梯度+优化器 | 估算激活值 (batch=1, seq=2048) | 总计 |
|---|---|---|---|---|
| LLaMA-7B | 7B | ~112 GB | ~8 GB | ~120 GB |
| LLaMA-13B | 13B | ~208 GB | ~15 GB | ~223 GB |
| LLaMA-70B | 70B | ~1120 GB | ~80 GB | ~1200 GB |
对比单卡显存(A100 80GB、H100 80GB)可以看到:
- 7B:单卡放不下完整训练状态,需要至少 2 卡
- 70B:需要至少 16 张 A100
- 更大的模型需要上百甚至上千张卡
这就是分布式训练的动机:用多张 GPU 的显存和算力,共同完成单卡无法胜任的训练任务。
2.3 激活值显存
激活值(Activations)是前向传播过程中保存的中间结果,反向传播时用来计算梯度。激活值的显存与 batch size、序列长度、隐藏维度、层数 都相关:
1 | 每层 Attention 激活 ≈ 2 * batch * seq * d_model * (1 + seq/d_model) |
激活值通常是显存的”弹性部分”——可以通过 Activation Checkpointing 用计算换显存来压缩。
3. 数据并行
数据并行是最基础、最直观的分布式训练方式:每张 GPU 持有完整的模型副本,各自处理不同的数据批次,然后同步梯度。
白话理解:数据并行就像多个学生各拿一份完整的试卷(模型副本),每人做不同的题目(不同数据批次),做完后大家”对答案”求平均(梯度同步)——人多力量大,做题速度成倍提升。
3.1 DP(DataParallel)
PyTorch 最早期的数据并行实现,单进程多线程:
1 | model = nn.DataParallel(model) |
工作流程:
- GPU 0(主卡)广播模型参数到所有 GPU
- 数据均匀切分到各 GPU
- 各 GPU 独立前向传播
- 梯度汇总到 GPU 0
- GPU 0 更新参数
缺点:
- GPU 0 是瓶颈:梯度汇总和参数更新都在主卡,显存和计算不均衡
- GIL 限制:Python 全局解释器锁导致多线程效率差
- 通信效率低:梯度先汇总到主卡再广播,不如 AllReduce 高效
结论:DP 已被淘汰,不要在新项目中使用。
3.2 DDP(DistributedDataParallel)
DDP 是 PyTorch 推荐的数据并行方案,每个 GPU 运行一个独立的进程:
1 | GPU 0 (进程 0): 模型副本 → 数据 batch_0 → 梯度_0 ─┐ |
核心机制:
- 每个进程持有完整的模型副本和优化器
- 前向传播完全独立
- 反向传播中使用 AllReduce 同步梯度(边算梯度边通信,重叠计算与通信)
- 各进程独立用平均梯度更新参数(更新后参数完全一致)
完整的 DDP 代码示例:
1 | import torch |
启动命令:
1 | # 单机 4 卡 |
DDP 的梯度同步细节:
DDP 不是等所有梯度算完再做一次大的 AllReduce,而是将参数分成多个 Bucket,每个 Bucket 的梯度算完就立刻启动 AllReduce,与后续层的反向传播重叠:
1 | 时间 → |
3.3 FSDP(Fully Sharded Data Parallel)
FSDP 是 PyTorch 原生的 ZeRO-3 实现。它在 DDP 基础上更进一步:不仅同步梯度,还将模型参数、梯度、优化器状态都切分到各 GPU。
1 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
FSDP 的核心思想将在下文 ZeRO 章节详细展开。
3.4 数据并行的通信量
DDP 中,每个训练步骤需要一次 AllReduce,通信量 = 模型参数量:
1 | 每 GPU 的 AllReduce 通信量 ≈ 2 * P * sizeof(dtype) |
4. 模型并行(3D 并行)
当模型太大、单卡放不下时,需要将模型本身切分到多张 GPU 上。3D 并行是指 张量并行(TP)+ 流水线并行(PP)+ 数据并行(DP) 的组合。
4.1 张量并行(Tensor Parallelism, TP)
张量并行将单个矩阵乘法切分到多张 GPU 上并行计算。
白话理解:把一个大矩阵切开,每个 GPU 算一块,最后拼起来——像多人合力抬一块大石板,每人抬一角,谁也不用独自承受全部重量。
以 Megatron-LM 的方案为例:
列并行(Column Parallel)
将权重矩阵 W 沿列维度切分:
1 | W W_0 W_1 |
行并行(Row Parallel)
将权重矩阵 W 沿行维度切分:
1 | W W_0 |
Transformer 层中的 TP 切分
Megatron-LM 巧妙地将列并行和行并行配对使用,减少通信次数:
1 | Self-Attention: |
每个 Transformer 层的 TP 通信:
1 | 前向传播: 2 次 AllReduce (Attention 1 次 + FFN 1 次) |
为什么 TP 通常限制在单机内:
以 80 层模型为例,每步训练需要 80 * 4 = 320 次 AllReduce。如果走 NVLink(900 GB/s),每次 AllReduce 耗时很短;但如果走跨机 IB 网络(~50 GB/s),延迟增加 18 倍,训练速度断崖式下降。
4.2 流水线并行(Pipeline Parallelism, PP)
流水线并行将模型的不同层分配到不同 GPU 上:
白话理解:把模型按层切分到不同 GPU,数据像流水线上的产品一样依次经过每个工位——第一个工人装轮子(前几层),第二个工人装车门(中间层),第三个工人喷漆(后几层),产品依次流过每个工位完成加工。
1 | GPU 0: [Embedding + Layer 0-19] |
数据在 GPU 之间通过 点对点(Send/Recv) 传递:
1 | GPU 0 计算完 Layer 0-19 → 发送激活值给 GPU 1 |
流水线气泡(Pipeline Bubble)
朴素的流水线并行存在严重的空闲等待问题:
1 | 时间 → |
解决方案:Micro-batch
将一个 mini-batch 拆分为 M 个 micro-batch,让流水线更充实:
1 | 时间 → (4 个 micro-batch, m1-m4) |
GPipe vs 1F1B
| 方案 | 调度策略 | 显存 | 气泡 |
|---|---|---|---|
| GPipe | 所有 micro-batch 前向完 → 所有反向 | 高(保存所有 micro-batch 激活) | (P-1)/(P-1+M) |
| 1F1B (Megatron) | 前向和反向交替执行 | 低(及时释放激活) | 相同 |
1F1B(One Forward One Backward)的调度示意:
1 | GPU 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3]... |
Warmup 阶段连续前向,然后进入稳态后每次 1 前向 + 1 反向交替。
4.3 序列并行(Sequence Parallelism, SP)
序列并行沿序列维度切分激活值,主要目的是减少激活值的显存占用。它通常与 TP 配合使用。
在 TP 中,LayerNorm 和 Dropout 等操作的输入是完整的(没有被切分),它们的激活值占用显存。SP 将这些操作的输入沿序列维度切分,每个 GPU 只持有 1/N 的激活值。
1 | 标准 TP: |
通信原语发生变化(AllReduce → AllGather + ReduceScatter),但总通信量不变。
4.4 3D 并行的组合
实际大模型训练通常同时使用 TP + PP + DP。以 128 张 GPU 训练为例:
1 | 128 GPUs = TP_size=8 × PP_size=4 × DP_size=4 |
1 | 数据并行组 0 数据并行组 1 |
选择原则:
- TP:放在单机内(NVLink 带宽高),通信最密集
- PP:可以跨机,通信量相对小(只传激活值)
- DP:跨机,AllReduce 梯度
5. ZeRO 显存优化
ZeRO(Zero Redundancy Optimizer)是微软 DeepSpeed 提出的显存优化系列,核心思想:DDP 中每个 GPU 持有完整的模型参数、梯度、优化器状态是巨大的浪费,可以将它们切分到各 GPU 上。
白话理解:正常情况下每张卡都存一份完整的”仓库”(优化器状态、梯度、参数),ZeRO 的思路是大家分工,每人只保管一部分,需要时再问别人借——用通信换显存。就像合租的室友不必每人买一套完整的工具箱,大家各买几件,要用时互相借就行。
5.1 冗余分析
在标准 DDP 中,N 个 GPU 的总显存占用:
1 | 总浪费 = N 份 × (参数 + 梯度 + 优化器状态) |
8 卡 DDP 训练 7B 模型:总占用 8 × 112GB = 896 GB,其中 784 GB 是冗余!
5.2 ZeRO-1:切分优化器状态
每个 GPU 只保存 1/N 的优化器状态(Adam 的 m 和 v):
1 | GPU 0 GPU 1 GPU 2 GPU 3 |
额外通信:更新参数后需要 AllGather 把更新后的参数广播给所有 GPU。
5.3 ZeRO-2:切分梯度 + 优化器状态
在 ZeRO-1 基础上,梯度也切分到各 GPU:
1 | GPU 0 GPU 1 GPU 2 GPU 3 |
通信变化:反向传播中用 ReduceScatter 替代 AllReduce,每个 GPU 只得到自己负责的那部分梯度。
5.4 ZeRO-3:切分一切
参数、梯度、优化器状态全部切分:
1 | GPU 0 GPU 1 GPU 2 GPU 3 |
代价:前向和反向传播中都需要 AllGather 来临时收集完整的层参数,用完即释放。
1 | 前向传播某一层: |
5.5 ZeRO 三级对比
| 参数 | 梯度 | 优化器状态 | 每 GPU 显存 | 额外通信 | |
|---|---|---|---|---|---|
| DDP | 完整 | 完整 | 完整 | 16P | AllReduce (梯度) |
| ZeRO-1 | 完整 | 完整 | 1/N | 4P + 4P + 8P/N | + AllGather (参数更新) |
| ZeRO-2 | 完整 | 1/N | 1/N | 4P + 12P/N | ReduceScatter + AllGather |
| ZeRO-3 | 1/N | 1/N | 1/N | 16P/N | 前向/反向各一次 AllGather |
通信量对比(每个训练步):
| 方案 | 通信量 (每 GPU) | 通信次数 |
|---|---|---|
| DDP | 2P | 1 (AllReduce) |
| ZeRO-1 | 2P + P = 3P | AllReduce + AllGather |
| ZeRO-2 | 2P + P = 3P | ReduceScatter + AllGather |
| ZeRO-3 | 2P + 3P = 5P | 多次 AllGather + ReduceScatter |
ZeRO-3 通信量是 DDP 的 ~2.5 倍,但显存节省巨大。在 GPU 数量很多时(如 64+),显存节省远比通信开销增加更有价值。
6. 混合精度训练
6.1 为什么用混合精度
- 节省显存:FP16 参数比 FP32 小一半
- 加速计算:Tensor Core 的 FP16/BF16 算力是 FP32 的 ~15 倍(H100)
- 减少通信量:梯度同步的数据量减半
6.2 FP16 vs BF16
| FP16 | BF16 | |
|---|---|---|
| 符号位 | 1 | 1 |
| 指数位 | 5 | 8 |
| 尾数位 | 10 | 7 |
| 表示范围 | ±65504 | ±3.4×10^38(与 FP32 相同) |
| 精度 | 更高(尾数位多) | 更低 |
| 溢出风险 | 高(范围小) | 低(范围大) |
| Tensor Core 支持 | Volta+ | Ampere+ |
关键区别:
- FP16 容易溢出:指数位只有 5 位,表示范围小,梯度容易超出范围(gradient overflow)
- BF16 几乎不溢出:指数位与 FP32 相同,表示范围一致
- BF16 精度略低:尾数位少 3 位,但实践中对训练影响不大
结论:如果硬件支持(Ampere+),优先使用 BF16,不需要 Loss Scaling。
6.3 混合精度训练流程
1 | ┌─────── FP32 主副本(Master Weights)───────┐ |
Loss Scaling(损失缩放)
FP16 训练中,梯度值可能非常小(如 1e-7),超出 FP16 的表示精度变为 0(gradient underflow)。Loss Scaling 通过放大 loss 来放大梯度,更新时再缩回来:
白话理解:FP16 的数值范围小,梯度太小会变成 0(下溢)。Loss Scaling 先把 loss 放大,算完梯度再缩回来——就像用放大镜看蚂蚁,看清楚细节后再缩回原始比例记录。
1 | # 手动 Loss Scaling |
BF16 不需要 Loss Scaling,因为其表示范围与 FP32 相同,梯度不会 underflow。
6.4 PyTorch 混合精度实战
1 | import torch |
7. 其他显存优化技术
7.1 梯度累积(Gradient Accumulation)
在显存不足以跑大 batch 时,可以将一个大 batch 拆成多个小 micro-batch,累积梯度后再更新:
白话理解:显存不够一次算大 batch?那就分几次算小 batch,把梯度攒起来再一起更新——就像攒零钱换整钱,每次攒一点,攒够了再花。
1 | accumulation_steps = 4 # 等效 batch_size = micro_batch_size * 4 |
效果:
- 显存占用 = 小 micro-batch 的显存
- 训练效果 ≈ 大 batch(梯度累积等价于大 batch 求平均)
- 代价:训练步骤变慢(每个有效 step 需要多次前向/反向)
7.2 Activation Checkpointing(激活重计算)
正常训练中,前向传播的所有中间激活值都要保存到反向传播时用来算梯度。Activation Checkpointing 的思路:不保存中间激活值,反向传播时重新计算。
1 | 正常训练: |
1 | # PyTorch 原生 Activation Checkpointing |
7.3 显存优化技术对比
| 技术 | 原理 | 节省 | 代价 |
|---|---|---|---|
| 混合精度 (FP16/BF16) | 低精度存储和计算 | 参数/激活显存减半 | 精度风险(FP16 需 Loss Scaling) |
| 梯度累积 | 小 batch 累积梯度 | 激活值显存降低 | 训练步骤变慢 |
| Activation Checkpointing | 丢弃激活值,反向时重算 | 激活值显存降到 O(1) | ~33% 额外计算 |
| ZeRO-1 | 切分优化器状态 | 优化器显存 /N | 少量额外通信 |
| ZeRO-2 | 切分梯度 + 优化器 | 梯度+优化器 /N | 额外通信 |
| ZeRO-3 | 切分一切 | 总显存 /N | 通信量 x2.5 |
8. 训练框架实战
8.1 DeepSpeed
DeepSpeed 是微软开发的分布式训练库,核心能力是 ZeRO 系列。使用非常简单——只需要一个 JSON 配置文件和少量代码修改。
DeepSpeed ZeRO-2 配置
1 | { |
DeepSpeed ZeRO-3 配置
1 | { |
代码集成
1 | import deepspeed |
启动方式
1 | # 单机 4 卡 |
8.2 Megatron-LM
Megatron-LM 是 NVIDIA 开发的大模型训练框架,核心贡献是 张量并行(TP)和流水线并行(PP) 的高效实现。
1 | # Megatron-LM 典型启动命令 |
8.3 HuggingFace Accelerate + DeepSpeed
对于使用 HuggingFace 生态的用户,Accelerate 提供了最简单的集成方式:
1 | from accelerate import Accelerator |
配置文件通过 accelerate config 交互式生成,支持 DDP / FSDP / DeepSpeed。
9. 并行策略选型指南
9.1 决策流程
1 | 模型能放进单卡吗? |
9.2 常见配置参考
| 模型规模 | GPU 数量 | 推荐方案 |
|---|---|---|
| 1-7B | 1-8 | DDP + BF16 + Activation Checkpointing |
| 7-13B | 4-16 | ZeRO-2/3 + BF16 |
| 13-70B | 16-64 | ZeRO-3 或 TP=8 + PP=2-4 + DP |
| 70B+ | 64-256+ | TP=8 + PP=4-16 + DP + ZeRO-1 |
| MoE (>100B) | 64-512+ | TP + PP + EP(Expert Parallel)+ DP |
9.3 权衡思维
分布式训练没有银弹。每种技术都是在 “计算、通信、显存” 三角中做取舍:
| 技术 | 牺牲 | 换取 |
|---|---|---|
| ZeRO-3 | 通信量 x2.5 | 显存 /N |
| Activation Checkpointing | 33% 额外计算 | 激活值显存降至 O(1) |
| 流水线并行 | 流水线气泡(空闲时间) | 模型跨机分布 |
| 张量并行 | 高频 AllReduce 通信 | 单层参数跨卡分布 |
| 混合精度 | 精度风险 | 显存减半 + 算力翻倍 |
| 梯度累积 | 训练步骤变慢 | 等效大 batch |
10. 自我检验清单
完成本文学习后,你应该能够:
- 能用 PyTorch DDP 将单卡训练脚本改造为多卡分布式训练(含
torchrun启动) - 能用公式计算给定模型的参数量、优化器状态、梯度所需的显存占用(16P 公式)
- 能解释 ZeRO-1/2/3 各切分了什么(优化器状态/梯度/参数),通信量如何变化
- 能画出 TP + PP 的 3D 并行拓扑图,标注通信位置与通信量
- 能解释张量并行中列并行和行并行的切分方式及它们如何配对减少通信
- 能解释流水线并行的气泡问题以及 micro-batch 如何缓解
- 能解释混合精度训练为什么需要 Loss Scaling,以及 BF16 vs FP16 的差异
- 能配置 DeepSpeed ZeRO Stage 2/3 并跑通一个训练任务
- 能解释梯度累积如何在有限显存下模拟更大 Batch Size
- 能解释 MHA、MQA、GQA 的区别对 KV Cache 大小的影响
- 能根据模型规模和 GPU 数量给出合理的并行策略方案
📚 参考资料
- Attention Is All You Need
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
- Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- DeepSpeed - GitHub
- DeepSpeed Documentation
- PyTorch DDP Tutorial
- PyTorch FSDP Tutorial
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
- DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models
- GPipe: Easy Scaling with Micro-Batch Pipeline Parallelism
- Reducing Activation Recomputation in Large Transformer Models (Sequence Parallelism)
- HuggingFace Accelerate Documentation
- 苏剑林:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA