分布式训练入门

当模型参数量超越单卡显存极限时,分布式训练就是必经之路。本文从 Transformer 模型基础讲起,系统覆盖数据并行、模型并行(3D 并行)、ZeRO 显存优化、混合精度训练等核心技术,并提供 PyTorch DDP 和 DeepSpeed 的实战代码,帮助从业者建立分布式训练的完整知识体系。

📑 目录


1. 模型基础:为什么需要分布式

1.1 Transformer 架构速览

几乎所有现代大语言模型都基于 Transformer 架构。理解其核心组件是理解分布式训练切分策略的前提。

一个 Transformer 层(Layer)由两个子模块组成:

1
2
3
4
5
6
7
输入 x

[Self-Attention] ── Q, K, V 投影 → Attention 计算 → 输出投影
↓ (残差连接 + LayerNorm)
[FFN] ── 上投影 → 激活函数 → 下投影
↓ (残差连接 + LayerNorm)
输出

Self-Attention:核心计算是 Q、K、V 三个矩阵的投影和 Attention 分数计算。

1
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

其中 Q、K、V 由输入 x 通过三个权重矩阵 W_Q、W_K、W_V 投影得到。

FFN(Feed-Forward Network):通常是两层线性变换 + 激活函数。

1
FFN(x) = W_down * activation(W_up * x)

W_up 将隐藏维度从 d_model 扩展到 d_ff(通常 4 倍),W_down 再缩回 d_model。

一个完整的 LLM = Embedding 层 + N 个 Transformer 层 + LM Head(输出层)。

1.2 Attention 变种

不同的 Attention 机制对显存和计算量有直接影响:

MHA(Multi-Head Attention)

标准多头注意力。每个 Head 拥有独立的 Q、K、V 投影。

1
2
3
4
5
Head 0: Q_0, K_0, V_0  → Attention_0
Head 1: Q_1, K_1, V_1 → Attention_1
...
Head h: Q_h, K_h, V_h → Attention_h
→ Concat → 输出投影

KV Cache 大小 = 2 * n_layers * n_heads * d_head * seq_len * batch_size

MQA(Multi-Query Attention)

所有 Head 共享同一组 K、V,只有 Q 是独立的。KV Cache 减少到 1/n_heads。

GQA(Grouped-Query Attention)

MHA 和 MQA 的折中。将 Head 分成若干组(Group),组内共享 K、V。LLaMA 2 70B、Mistral 等采用。

MLA(Multi-head Latent Attention)

DeepSeek V2 提出。通过低秩压缩将 KV 投影到低维潜空间,进一步减少 KV Cache。

1
2
3
4
MHA:  KV Cache = 2 * n_heads * d_head       (每层每 token)
MQA: KV Cache = 2 * d_head (共享 KV)
GQA: KV Cache = 2 * n_groups * d_head (折中)
MLA: KV Cache = d_compressed (低秩压缩)

1.3 FFN 变种:MoE(混合专家模型)

标准 FFN 中所有 token 都经过同一组参数。MoE(Mixture of Experts)将 FFN 替换为多个”专家”网络,每个 token 只激活其中 Top-K 个专家。

1
2
3
4
5
6
7
8
9
10
标准 FFN:
x → [FFN] → y # 所有 token 用同一组权重

MoE FFN:
x → [Router/Gate] → 选择 Top-K 专家
├→ Expert 0: FFN_0(x) # 只有被选中的专家参与计算
├→ Expert 1: FFN_1(x)
├→ ...
└→ Expert N: FFN_N(x)
→ 加权求和 → y

关键特性

  • 总参数量大:例如 8 个专家意味着 FFN 参数量 x8
  • 单 token 计算量小:每个 token 只激活 Top-2 个专家,计算量接近普通 FFN
  • 对分布式有特殊要求:不同专家可以放在不同 GPU 上(Expert Parallelism)

DeepSeek V3 使用了 256 个细粒度专家 + 共享专家的 MoE 架构,总参数 671B 但每 token 仅激活 37B。


2. 显存占用分析

在讨论分布式策略之前,必须先搞清楚训练时的显存都花在了哪里。

2.1 显存占用的四大部分

以一个参数量为 P(单位:元素数)的模型为例:

项目 FP32 训练 混合精度训练 (FP16/BF16)
模型参数 4P 字节 2P (FP16) + 4P (FP32 主副本)
梯度 4P 字节 2P (FP16) 或 4P (FP32)
优化器状态 (Adam) 8P 字节 (m + v) 8P 字节 (FP32 的 m + v)
激活值 取决于 batch/seq/层数 取决于 batch/seq/层数

混合精度下 Adam 训练的显存公式

1
2
3
4
5
6
7
模型参数:        2P   (FP16 参数)
FP32 主副本: 4P (用于梯度更新的精确副本)
梯度: 2P (FP16 梯度)
优化器 m (一阶矩): 4P (FP32)
优化器 v (二阶矩): 4P (FP32)
────────────────────────
总计: 16P 字节(不含激活值)

2.2 实际案例估算

模型 参数量 P 参数+梯度+优化器 估算激活值 (batch=1, seq=2048) 总计
LLaMA-7B 7B ~112 GB ~8 GB ~120 GB
LLaMA-13B 13B ~208 GB ~15 GB ~223 GB
LLaMA-70B 70B ~1120 GB ~80 GB ~1200 GB

对比单卡显存(A100 80GB、H100 80GB)可以看到:

  • 7B:单卡放不下完整训练状态,需要至少 2 卡
  • 70B:需要至少 16 张 A100
  • 更大的模型需要上百甚至上千张卡

这就是分布式训练的动机:用多张 GPU 的显存和算力,共同完成单卡无法胜任的训练任务。

2.3 激活值显存

激活值(Activations)是前向传播过程中保存的中间结果,反向传播时用来计算梯度。激活值的显存与 batch size、序列长度、隐藏维度、层数 都相关:

1
2
3
4
每层 Attention 激活 ≈ 2 * batch * seq * d_model * (1 + seq/d_model)
每层 FFN 激活 ≈ 2 * batch * seq * d_ff

总激活 ≈ n_layers * (Attention 激活 + FFN 激活)

激活值通常是显存的”弹性部分”——可以通过 Activation Checkpointing 用计算换显存来压缩。


3. 数据并行

数据并行是最基础、最直观的分布式训练方式:每张 GPU 持有完整的模型副本,各自处理不同的数据批次,然后同步梯度。

白话理解:数据并行就像多个学生各拿一份完整的试卷(模型副本),每人做不同的题目(不同数据批次),做完后大家”对答案”求平均(梯度同步)——人多力量大,做题速度成倍提升。

3.1 DP(DataParallel)

PyTorch 最早期的数据并行实现,单进程多线程:

1
model = nn.DataParallel(model)

工作流程

  1. GPU 0(主卡)广播模型参数到所有 GPU
  2. 数据均匀切分到各 GPU
  3. 各 GPU 独立前向传播
  4. 梯度汇总到 GPU 0
  5. GPU 0 更新参数

缺点

  • GPU 0 是瓶颈:梯度汇总和参数更新都在主卡,显存和计算不均衡
  • GIL 限制:Python 全局解释器锁导致多线程效率差
  • 通信效率低:梯度先汇总到主卡再广播,不如 AllReduce 高效

结论:DP 已被淘汰,不要在新项目中使用。

3.2 DDP(DistributedDataParallel)

DDP 是 PyTorch 推荐的数据并行方案,每个 GPU 运行一个独立的进程:

1
2
3
4
GPU 0 (进程 0): 模型副本 → 数据 batch_0 → 梯度_0 ─┐
GPU 1 (进程 1): 模型副本 → 数据 batch_1 → 梯度_1 ─┤── AllReduce ──→ 平均梯度
GPU 2 (进程 2): 模型副本 → 数据 batch_2 → 梯度_2 ─┤ ↓
GPU 3 (进程 3): 模型副本 → 数据 batch_3 → 梯度_3 ─┘ 各 GPU 独立更新参数

核心机制

  • 每个进程持有完整的模型副本和优化器
  • 前向传播完全独立
  • 反向传播中使用 AllReduce 同步梯度(边算梯度边通信,重叠计算与通信)
  • 各进程独立用平均梯度更新参数(更新后参数完全一致)

完整的 DDP 代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
"""初始化分布式环境"""
dist.init_process_group(
backend='nccl', # GPU 通信用 NCCL
init_method='env://', # 通过环境变量传递地址
rank=rank,
world_size=world_size
)
torch.cuda.set_device(rank)

def cleanup():
dist.destroy_process_group()

def train(rank, world_size):
setup(rank, world_size)

# 创建模型并包装为 DDP
model = MyModel().to(rank)
model = DDP(model, device_ids=[rank])

# 使用 DistributedSampler 确保每个 GPU 拿到不同的数据
dataset = MyDataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
sampler.set_epoch(epoch) # 每个 epoch 重新 shuffle
for batch in dataloader:
inputs, labels = batch
inputs = inputs.to(rank)
labels = labels.to(rank)

outputs = model(inputs)
loss = criterion(outputs, labels)

optimizer.zero_grad()
loss.backward() # DDP 自动在反向传播中 AllReduce 梯度
optimizer.step()

cleanup()

# 启动方式
# torchrun --nproc_per_node=4 train.py

启动命令

1
2
3
4
5
6
7
8
9
10
# 单机 4 卡
torchrun --nproc_per_node=4 train.py

# 2 机 8 卡(每机 4 卡)
# 机器 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=192.168.1.1 --master_port=29500 train.py
# 机器 2:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=192.168.1.1 --master_port=29500 train.py

DDP 的梯度同步细节

DDP 不是等所有梯度算完再做一次大的 AllReduce,而是将参数分成多个 Bucket,每个 Bucket 的梯度算完就立刻启动 AllReduce,与后续层的反向传播重叠:

1
2
3
4
时间 →
反向传播: [Layer N 梯度][Layer N-1 梯度][Layer N-2 梯度]...
AllReduce: [Bucket 2 通信 ][Bucket 1 通信 ]...
↑ 重叠:通信和计算并行进行

3.3 FSDP(Fully Sharded Data Parallel)

FSDP 是 PyTorch 原生的 ZeRO-3 实现。它在 DDP 基础上更进一步:不仅同步梯度,还将模型参数、梯度、优化器状态都切分到各 GPU

1
2
3
4
5
6
7
8
9
10
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = MyModel().to(rank)
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 等同 ZeRO-3
# ShardingStrategy.SHARD_GRAD_OP → 等同 ZeRO-2
# ShardingStrategy.NO_SHARD → 等同 DDP
)

FSDP 的核心思想将在下文 ZeRO 章节详细展开。

3.4 数据并行的通信量

DDP 中,每个训练步骤需要一次 AllReduce,通信量 = 模型参数量:

1
2
3
4
5
6
7
每 GPU 的 AllReduce 通信量 ≈ 2 * P * sizeof(dtype)
(Ring AllReduce: 发送 + 接收各一次)

示例:7B 模型 FP16
通信量 = 2 * 7B * 2 bytes = 28 GB
8 卡 NVLink (900 GB/s): ~31 ms
跨机 IB 400Gb (50 GB/s): ~560 ms

4. 模型并行(3D 并行)

当模型太大、单卡放不下时,需要将模型本身切分到多张 GPU 上。3D 并行是指 张量并行(TP)+ 流水线并行(PP)+ 数据并行(DP) 的组合。

4.1 张量并行(Tensor Parallelism, TP)

张量并行将单个矩阵乘法切分到多张 GPU 上并行计算。

白话理解:把一个大矩阵切开,每个 GPU 算一块,最后拼起来——像多人合力抬一块大石板,每人抬一角,谁也不用独自承受全部重量。

以 Megatron-LM 的方案为例:

列并行(Column Parallel)

将权重矩阵 W 沿列维度切分:

1
2
3
4
5
6
7
              W                              W_0        W_1
[d x 4d] → [d x 2d] [d x 2d]
GPU 0 GPU 1

x @ W = x @ [W_0 | W_1] = [x @ W_0 | x @ W_1]

每个 GPU 计算一半的输出列,最后拼接(AllGather)

行并行(Row Parallel)

将权重矩阵 W 沿行维度切分:

1
2
3
4
5
6
7
              W                              W_0
[4d x d] → [2d x d] GPU 0
W_1
[2d x d] GPU 1

输入也需要切分: x = [x_0 | x_1]
x @ W = x_0 @ W_0 + x_1 @ W_1 → AllReduce 求和

Transformer 层中的 TP 切分

Megatron-LM 巧妙地将列并行和行并行配对使用,减少通信次数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Self-Attention:
输入 x (所有 GPU 都有完整副本)

Q, K, V 投影 [列并行] → 每 GPU 算部分 Head → Attention

输出投影 [行并行] → AllReduce → 完整输出
↓ ← 整个 Attention 只需 1 次 AllReduce

FFN:
输入 x

上投影 W_up [列并行] → 激活函数

下投影 W_down [行并行] → AllReduce → 完整输出
↓ ← 整个 FFN 只需 1 次 AllReduce

每个 Transformer 层的 TP 通信

1
2
3
前向传播: 2 次 AllReduce (Attention 1 次 + FFN 1 次)
反向传播: 2 次 AllReduce
总计每层: 4 次 AllReduce

为什么 TP 通常限制在单机内

以 80 层模型为例,每步训练需要 80 * 4 = 320 次 AllReduce。如果走 NVLink(900 GB/s),每次 AllReduce 耗时很短;但如果走跨机 IB 网络(~50 GB/s),延迟增加 18 倍,训练速度断崖式下降。

4.2 流水线并行(Pipeline Parallelism, PP)

流水线并行将模型的不同层分配到不同 GPU 上:

白话理解:把模型按层切分到不同 GPU,数据像流水线上的产品一样依次经过每个工位——第一个工人装轮子(前几层),第二个工人装车门(中间层),第三个工人喷漆(后几层),产品依次流过每个工位完成加工。

1
2
3
4
GPU 0: [Embedding + Layer 0-19]
GPU 1: [Layer 20-39]
GPU 2: [Layer 40-59]
GPU 3: [Layer 60-79 + LM Head]

数据在 GPU 之间通过 点对点(Send/Recv) 传递:

1
2
3
4
GPU 0 计算完 Layer 0-19 → 发送激活值给 GPU 1
GPU 1 计算 Layer 20-39 → 发送激活值给 GPU 2
GPU 2 计算 Layer 40-59 → 发送激活值给 GPU 3
GPU 3 计算 Layer 60-79 → 计算 loss → 反向传播沿反方向传

流水线气泡(Pipeline Bubble)

朴素的流水线并行存在严重的空闲等待问题:

1
2
3
4
5
6
7
8
时间 →
GPU 0: [F ][ ][ ][ ][ ][ ][B ][更新]
GPU 1: [ ][F ][ ][ ][ ][B ][ ][更新]
GPU 2: [ ][ ][F ][ ][B ][ ][ ][更新]
GPU 3: [ ][ ][ ][F ][B ][ ][ ][更新]
↑ 大量空闲时间(气泡)

气泡占比 = (P-1) / P ,P=4 时气泡占 75%!

解决方案:Micro-batch

将一个 mini-batch 拆分为 M 个 micro-batch,让流水线更充实:

1
2
3
4
5
6
7
8
9
时间 →  (4 个 micro-batch, m1-m4)
GPU 0: [F1][F2][F3][F4][ ][B4][B3][B2][B1][更新]
GPU 1: [ ][F1][F2][F3][F4][B4][B3][B2][B1][更新]
GPU 2: [ ][ ][F1][F2][F3][F4][B4][B3][B2][更新]
GPU 3: [ ][ ][ ][F1][F2][F3][F4][B3][B2][更新]

气泡占比 = (P-1) / (P-1+M)
P=4, M=4: 气泡 = 3/7 ≈ 43%
P=4, M=32: 气泡 = 3/35 ≈ 8.6% ← 可接受

GPipe vs 1F1B

方案 调度策略 显存 气泡
GPipe 所有 micro-batch 前向完 → 所有反向 高(保存所有 micro-batch 激活) (P-1)/(P-1+M)
1F1B (Megatron) 前向和反向交替执行 低(及时释放激活) 相同

1F1B(One Forward One Backward)的调度示意:

1
GPU 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3]...

Warmup 阶段连续前向,然后进入稳态后每次 1 前向 + 1 反向交替。

4.3 序列并行(Sequence Parallelism, SP)

序列并行沿序列维度切分激活值,主要目的是减少激活值的显存占用。它通常与 TP 配合使用。

在 TP 中,LayerNorm 和 Dropout 等操作的输入是完整的(没有被切分),它们的激活值占用显存。SP 将这些操作的输入沿序列维度切分,每个 GPU 只持有 1/N 的激活值。

1
2
3
4
5
6
7
标准 TP:
LayerNorm(完整 x) → TP 矩阵乘法 → AllReduce → LayerNorm(完整 x)
↑ LayerNorm 的激活值在每个 GPU 上都是完整的

TP + SP:
LayerNorm(x 的 1/N) → AllGather → TP 矩阵乘法 → ReduceScatter → LayerNorm(x 的 1/N)
↑ LayerNorm 的激活值在每个 GPU 上只有 1/N

通信原语发生变化(AllReduce → AllGather + ReduceScatter),但总通信量不变。

4.4 3D 并行的组合

实际大模型训练通常同时使用 TP + PP + DP。以 128 张 GPU 训练为例:

1
2
3
4
5
128 GPUs = TP_size=8 × PP_size=4 × DP_size=4

单机 8 卡 (NVLink): TP=8,张量并行在单机内
4 个流水线级: PP=4,模型按层切分到 4 组
4 路数据并行: DP=4,4 组 GPU 处理不同数据
1
2
3
4
5
6
7
8
                    数据并行组 0        数据并行组 1
PP Stage 0: [GPU 0-7] TP=8 [GPU 32-39] TP=8
↓ ↓
PP Stage 1: [GPU 8-15] TP=8 [GPU 40-47] TP=8
↓ ↓
PP Stage 2: [GPU 16-23] TP=8 [GPU 48-55] TP=8
↓ ↓
PP Stage 3: [GPU 24-31] TP=8 [GPU 56-63] TP=8

选择原则

  • TP:放在单机内(NVLink 带宽高),通信最密集
  • PP:可以跨机,通信量相对小(只传激活值)
  • DP:跨机,AllReduce 梯度

5. ZeRO 显存优化

ZeRO(Zero Redundancy Optimizer)是微软 DeepSpeed 提出的显存优化系列,核心思想:DDP 中每个 GPU 持有完整的模型参数、梯度、优化器状态是巨大的浪费,可以将它们切分到各 GPU 上。

白话理解:正常情况下每张卡都存一份完整的”仓库”(优化器状态、梯度、参数),ZeRO 的思路是大家分工,每人只保管一部分,需要时再问别人借——用通信换显存。就像合租的室友不必每人买一套完整的工具箱,大家各买几件,要用时互相借就行。

5.1 冗余分析

在标准 DDP 中,N 个 GPU 的总显存占用:

1
2
3
4
5
6
总浪费 = N 份 × (参数 + 梯度 + 优化器状态)
= N × 16P 字节

有用的 = 1 份 × 16P 字节

冗余率 = (N-1) / N ≈ 100%(N 较大时)

8 卡 DDP 训练 7B 模型:总占用 8 × 112GB = 896 GB,其中 784 GB 是冗余!

5.2 ZeRO-1:切分优化器状态

每个 GPU 只保存 1/N 的优化器状态(Adam 的 m 和 v):

1
2
3
4
5
6
                    GPU 0      GPU 1      GPU 2      GPU 3
模型参数: 完整 P 完整 P 完整 P 完整 P
梯度: 完整 G 完整 G 完整 G 完整 G
优化器 (Adam m,v): 1/4 1/4 1/4 1/4

显存: 4P + 4P + 8P/4 = 10P (对比 DDP 的 16P,节省 37.5%)

额外通信:更新参数后需要 AllGather 把更新后的参数广播给所有 GPU。

5.3 ZeRO-2:切分梯度 + 优化器状态

在 ZeRO-1 基础上,梯度也切分到各 GPU:

1
2
3
4
5
6
                    GPU 0      GPU 1      GPU 2      GPU 3
模型参数: 完整 P 完整 P 完整 P 完整 P
梯度: 1/4 1/4 1/4 1/4
优化器 (Adam m,v): 1/4 1/4 1/4 1/4

显存: 4P + 4P/4 + 8P/4 = 7P (对比 DDP 的 16P,节省 56%)

通信变化:反向传播中用 ReduceScatter 替代 AllReduce,每个 GPU 只得到自己负责的那部分梯度。

5.4 ZeRO-3:切分一切

参数、梯度、优化器状态全部切分:

1
2
3
4
5
6
                    GPU 0      GPU 1      GPU 2      GPU 3
模型参数: 1/4 1/4 1/4 1/4
梯度: 1/4 1/4 1/4 1/4
优化器 (Adam m,v): 1/4 1/4 1/4 1/4

显存: 4P/4 + 4P/4 + 8P/4 = 4P (对比 DDP 的 16P,节省 75%)

代价:前向和反向传播中都需要 AllGather 来临时收集完整的层参数,用完即释放。

1
2
3
4
5
前向传播某一层:
AllGather 收集该层完整参数 → 计算 → 释放非本 GPU 的参数

反向传播某一层:
AllGather 收集该层完整参数 → 计算梯度 → ReduceScatter 切分梯度 → 释放

5.5 ZeRO 三级对比

参数 梯度 优化器状态 每 GPU 显存 额外通信
DDP 完整 完整 完整 16P AllReduce (梯度)
ZeRO-1 完整 完整 1/N 4P + 4P + 8P/N + AllGather (参数更新)
ZeRO-2 完整 1/N 1/N 4P + 12P/N ReduceScatter + AllGather
ZeRO-3 1/N 1/N 1/N 16P/N 前向/反向各一次 AllGather

通信量对比(每个训练步):

方案 通信量 (每 GPU) 通信次数
DDP 2P 1 (AllReduce)
ZeRO-1 2P + P = 3P AllReduce + AllGather
ZeRO-2 2P + P = 3P ReduceScatter + AllGather
ZeRO-3 2P + 3P = 5P 多次 AllGather + ReduceScatter

ZeRO-3 通信量是 DDP 的 ~2.5 倍,但显存节省巨大。在 GPU 数量很多时(如 64+),显存节省远比通信开销增加更有价值。


6. 混合精度训练

6.1 为什么用混合精度

  • 节省显存:FP16 参数比 FP32 小一半
  • 加速计算:Tensor Core 的 FP16/BF16 算力是 FP32 的 ~15 倍(H100)
  • 减少通信量:梯度同步的数据量减半

6.2 FP16 vs BF16

FP16 BF16
符号位 1 1
指数位 5 8
尾数位 10 7
表示范围 ±65504 ±3.4×10^38(与 FP32 相同)
精度 更高(尾数位多) 更低
溢出风险 高(范围小) 低(范围大)
Tensor Core 支持 Volta+ Ampere+

关键区别

  • FP16 容易溢出:指数位只有 5 位,表示范围小,梯度容易超出范围(gradient overflow)
  • BF16 几乎不溢出:指数位与 FP32 相同,表示范围一致
  • BF16 精度略低:尾数位少 3 位,但实践中对训练影响不大

结论:如果硬件支持(Ampere+),优先使用 BF16,不需要 Loss Scaling。

6.3 混合精度训练流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
     ┌─────── FP32 主副本(Master Weights)───────┐
│ │
↓ │
转为 FP16/BF16 参数 │
│ │
前向传播 (FP16/BF16) │
│ │
计算 loss │
│ │
Loss Scaling(仅 FP16 需要) │
│ │
反向传播 → FP16/BF16 梯度 │
│ │
梯度 Unscaling + 转为 FP32 │
│ │
FP32 优化器更新 FP32 主副本 ──────────────────────┘

Loss Scaling(损失缩放)

FP16 训练中,梯度值可能非常小(如 1e-7),超出 FP16 的表示精度变为 0(gradient underflow)。Loss Scaling 通过放大 loss 来放大梯度,更新时再缩回来:

白话理解:FP16 的数值范围小,梯度太小会变成 0(下溢)。Loss Scaling 先把 loss 放大,算完梯度再缩回来——就像用放大镜看蚂蚁,看清楚细节后再缩回原始比例记录。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 手动 Loss Scaling
loss = model(inputs)
scaled_loss = loss * scale_factor # 放大 loss
scaled_loss.backward() # 梯度也被放大
optimizer.step(scale=1/scale_factor) # 更新时缩回

# PyTorch 自动 Loss Scaling(推荐)
scaler = torch.cuda.amp.GradScaler() # 自动调整 scale factor

with torch.cuda.amp.autocast(): # 自动混合精度上下文
output = model(inputs)
loss = criterion(output, labels)

scaler.scale(loss).backward() # 自动缩放
scaler.step(optimizer) # 自动 unscale + 更新
scaler.update() # 自动调整 scale factor

BF16 不需要 Loss Scaling,因为其表示范围与 FP32 相同,梯度不会 underflow。

6.4 PyTorch 混合精度实战

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# BF16 训练(Ampere+ GPU,推荐)
for batch in dataloader:
inputs, labels = batch
inputs, labels = inputs.cuda(), labels.cuda()

with autocast(dtype=torch.bfloat16): # 自动将支持的操作转为 BF16
outputs = model(inputs)
loss = criterion(outputs, labels)

loss.backward() # BF16 不需要 GradScaler
optimizer.step()
optimizer.zero_grad()

# FP16 训练(需要 Loss Scaling)
scaler = GradScaler()
for batch in dataloader:
inputs, labels = batch
inputs, labels = inputs.cuda(), labels.cuda()

with autocast(dtype=torch.float16):
outputs = model(inputs)
loss = criterion(outputs, labels)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

7. 其他显存优化技术

7.1 梯度累积(Gradient Accumulation)

在显存不足以跑大 batch 时,可以将一个大 batch 拆成多个小 micro-batch,累积梯度后再更新:

白话理解:显存不够一次算大 batch?那就分几次算小 batch,把梯度攒起来再一起更新——就像攒零钱换整钱,每次攒一点,攒够了再花。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
accumulation_steps = 4  # 等效 batch_size = micro_batch_size * 4

for i, (inputs, labels) in enumerate(dataloader):
inputs, labels = inputs.cuda(), labels.cuda()

with autocast(dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, labels)
loss = loss / accumulation_steps # 归一化

loss.backward() # 梯度累积(不清零)

if (i + 1) % accumulation_steps == 0:
optimizer.step() # 累积够了才更新
optimizer.zero_grad() # 更新后清零

效果

  • 显存占用 = 小 micro-batch 的显存
  • 训练效果 ≈ 大 batch(梯度累积等价于大 batch 求平均)
  • 代价:训练步骤变慢(每个有效 step 需要多次前向/反向)

7.2 Activation Checkpointing(激活重计算)

正常训练中,前向传播的所有中间激活值都要保存到反向传播时用来算梯度。Activation Checkpointing 的思路:不保存中间激活值,反向传播时重新计算。

1
2
3
4
5
6
7
8
9
10
正常训练:
前向: x → Layer1 → a1 (保存) → Layer2 → a2 (保存) → ... → loss
反向: 用保存的 a1, a2, ... 计算梯度
显存: O(n_layers)

Activation Checkpointing:
前向: x → Layer1 → (丢弃 a1) → Layer2 → (丢弃 a2) → ... → loss
反向: 重新前向计算 a1, a2, ... → 计算梯度
显存: O(sqrt(n_layers)) 或 O(1)
代价: 增加约 33% 的计算量
1
2
3
4
5
6
7
8
9
10
# PyTorch 原生 Activation Checkpointing
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
def forward(self, x):
# 正常前向
return self.ffn(self.attention(x))

# 使用 checkpoint 包装
output = checkpoint(transformer_block, input, use_reentrant=False)

7.3 显存优化技术对比

技术 原理 节省 代价
混合精度 (FP16/BF16) 低精度存储和计算 参数/激活显存减半 精度风险(FP16 需 Loss Scaling)
梯度累积 小 batch 累积梯度 激活值显存降低 训练步骤变慢
Activation Checkpointing 丢弃激活值,反向时重算 激活值显存降到 O(1) ~33% 额外计算
ZeRO-1 切分优化器状态 优化器显存 /N 少量额外通信
ZeRO-2 切分梯度 + 优化器 梯度+优化器 /N 额外通信
ZeRO-3 切分一切 总显存 /N 通信量 x2.5

8. 训练框架实战

8.1 DeepSpeed

DeepSpeed 是微软开发的分布式训练库,核心能力是 ZeRO 系列。使用非常简单——只需要一个 JSON 配置文件和少量代码修改。

DeepSpeed ZeRO-2 配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
{
"train_batch_size": 64,
"gradient_accumulation_steps": 4,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"reduce_scatter": true,
"overlap_comm": true,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"weight_decay": 0.01
}
}
}

DeepSpeed ZeRO-3 配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
{
"train_batch_size": 64,
"gradient_accumulation_steps": 8,
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_param": {
"device": "none"
},
"offload_optimizer": {
"device": "none"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6
}
}

代码集成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import deepspeed

# 初始化
model = MyModel()
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
config="ds_config.json",
model_parameters=model.parameters()
)

# 训练循环(几乎和普通训练一样)
for batch in dataloader:
inputs, labels = batch
inputs = inputs.to(model_engine.device)
labels = labels.to(model_engine.device)

outputs = model_engine(inputs)
loss = criterion(outputs, labels)

model_engine.backward(loss) # 替代 loss.backward()
model_engine.step() # 替代 optimizer.step()

启动方式

1
2
3
4
5
# 单机 4 卡
deepspeed --num_gpus=4 train.py --deepspeed_config ds_config.json

# 多机(hostfile 格式: hostname slots=gpu_count)
deepspeed --hostfile hostfile train.py --deepspeed_config ds_config.json

8.2 Megatron-LM

Megatron-LM 是 NVIDIA 开发的大模型训练框架,核心贡献是 张量并行(TP)和流水线并行(PP) 的高效实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Megatron-LM 典型启动命令
python -m torch.distributed.launch \
--nproc_per_node=8 \
--nnodes=4 \
pretrain_gpt.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 4 \
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--micro-batch-size 1 \
--global-batch-size 1024 \
--seq-length 4096 \
--train-iters 500000 \
--lr 1.5e-4 \
--bf16

8.3 HuggingFace Accelerate + DeepSpeed

对于使用 HuggingFace 生态的用户,Accelerate 提供了最简单的集成方式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from accelerate import Accelerator

accelerator = Accelerator() # 自动检测分布式环境

model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(dataset, batch_size=32)

# 一行代码包装
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
outputs = model(batch['input_ids'])
loss = outputs.loss
accelerator.backward(loss) # 替代 loss.backward()
optimizer.step()
optimizer.zero_grad()

配置文件通过 accelerate config 交互式生成,支持 DDP / FSDP / DeepSpeed。


9. 并行策略选型指南

9.1 决策流程

1
2
3
4
5
6
7
8
9
10
11
12
13
模型能放进单卡吗?
├── 能 → DDP 就够了

└── 不能 → 显存缺口有多大?

├── 差一点(1-2x)→ ZeRO-2 + 混合精度 + Activation Checkpointing

├── 差不少(2-8x)→ ZeRO-3 或 FSDP

└── 差很多(>8x,大模型)→ 3D 并行
├── TP=机内卡数(通常 8)
├── PP=模型层数 / 每组层数
└── DP=剩余 GPU 数

9.2 常见配置参考

模型规模 GPU 数量 推荐方案
1-7B 1-8 DDP + BF16 + Activation Checkpointing
7-13B 4-16 ZeRO-2/3 + BF16
13-70B 16-64 ZeRO-3 或 TP=8 + PP=2-4 + DP
70B+ 64-256+ TP=8 + PP=4-16 + DP + ZeRO-1
MoE (>100B) 64-512+ TP + PP + EP(Expert Parallel)+ DP

9.3 权衡思维

分布式训练没有银弹。每种技术都是在 “计算、通信、显存” 三角中做取舍:

技术 牺牲 换取
ZeRO-3 通信量 x2.5 显存 /N
Activation Checkpointing 33% 额外计算 激活值显存降至 O(1)
流水线并行 流水线气泡(空闲时间) 模型跨机分布
张量并行 高频 AllReduce 通信 单层参数跨卡分布
混合精度 精度风险 显存减半 + 算力翻倍
梯度累积 训练步骤变慢 等效大 batch

10. 自我检验清单

完成本文学习后,你应该能够:

  • 能用 PyTorch DDP 将单卡训练脚本改造为多卡分布式训练(含 torchrun 启动)
  • 能用公式计算给定模型的参数量、优化器状态、梯度所需的显存占用(16P 公式)
  • 能解释 ZeRO-1/2/3 各切分了什么(优化器状态/梯度/参数),通信量如何变化
  • 能画出 TP + PP 的 3D 并行拓扑图,标注通信位置与通信量
  • 能解释张量并行中列并行和行并行的切分方式及它们如何配对减少通信
  • 能解释流水线并行的气泡问题以及 micro-batch 如何缓解
  • 能解释混合精度训练为什么需要 Loss Scaling,以及 BF16 vs FP16 的差异
  • 能配置 DeepSpeed ZeRO Stage 2/3 并跑通一个训练任务
  • 能解释梯度累积如何在有限显存下模拟更大 Batch Size
  • 能解释 MHA、MQA、GQA 的区别对 KV Cache 大小的影响
  • 能根据模型规模和 GPU 数量给出合理的并行策略方案

📚 参考资料