2.4 Transformer前馈网络FFN深入理解
前馈网络(FFN)是 Transformer 中参数量最大的模块,占据单个 Block 约 2/3 的参数。它负责对每个 token 进行独立的非线性变换,是模型”记忆知识”和”深度推理”的核心载体。本文从 FFN 的结构设计出发,深入剖析激活函数演进、参数量计算、SwiGLU 门控机制,并延伸到张量并行切分、MoE 专家并行和 CUDA kernel 融合等 AI Infra 工程实践。
📑 目录
- 1. FFN 在 Transformer 中的角色定位
- 2. FFN 的”展开-压缩”结构
- 3. 激活函数详解
- 4. SwiGLU:当前大模型的主流选择
- 5. 参数量计算与对比
- 6. FFN 中间维度的选择:为什么是 4x 和 8/3x
- 7. 张量并行切分 FFN
- 8. MoE:将 FFN 拆分为多个专家
- 9. CUDA Kernel 融合在 FFN 中的应用
- 自我检验清单
- 参考资料
1. FFN 在 Transformer 中的角色定位
要理解 FFN 的作用,先回顾一下 Transformer Block 的整体结构。每个 Block 包含两个核心子模块:Self-Attention 和 FFN,两者通过残差连接和 LayerNorm 串联起来。
Self-Attention 的职责是信息交互——让序列中的每个 token 能够”看见”其他 token 并汇聚相关信息。打个比方,如果把一个 Transformer Block 看作一家工厂的生产流水线,Self-Attention 就是流水线上的”原料分拣站”:它根据订单需求(Query)从仓库中各个货架(Key)上挑选相关原料(Value),把需要的材料汇集到一起。
而 FFN 则是紧随其后的”加工车间”。原料分拣完毕后,需要在加工车间里对每份原料进行独立的深加工——切割、熔炼、塑形——把粗加工的信息变成精细的、有结构的产出。关键在于:加工车间里的每条生产线是独立运作的,每个 token 各走各的生产线,彼此之间不再交互。
用更技术性的语言来说:
- Self-Attention 是 token 间的运算——输出依赖于序列中所有 token 的组合关系
- FFN 是 token 内的运算——对每个 token 的表示向量独立施加相同的非线性变换
这种”先交互、后加工”的两步设计有深层次的意义。仅靠线性的信息汇聚(Attention 本质上是加权求和,是一种线性运算),模型的表达能力会受到严重限制。FFN 引入的非线性变换赋予了模型逼近任意复杂函数的能力——根据万能近似定理(Universal Approximation Theorem),一个带有非线性激活函数的两层前馈网络可以逼近任意连续函数。
近年来的研究还发现,FFN 在大模型中扮演着”知识库”的角色。大量的事实性知识(”巴黎是法国的首都”、”水的分子式是 H2O”)被编码在 FFN 的参数中,而 Attention 层更多负责组织和提取这些知识。这也解释了为什么 FFN 占据了模型大部分的参数量——它需要足够大的”存储容量”来记忆海量知识。
2. FFN 的”展开-压缩”结构
2.1 基本结构
标准 FFN 的数学表达式非常简洁:
$$
\text{FFN}(x) = W_2 \cdot \sigma(W_1 \cdot x + b_1) + b_2
$$
其中:
- $x$ 是输入向量,维度为 $d_{model}$
- $W_1$ 的形状为 $(d_{model}, d_{ff})$,将维度从 $d_{model}$ 升高到 $d_{ff}$
- $\sigma$ 是非线性激活函数
- $W_2$ 的形状为 $(d_{ff}, d_{model})$,将维度从 $d_{ff}$ 压缩回 $d_{model}$
- $b_1, b_2$ 是偏置项(现代大模型通常省略偏置)
以 $d_{model} = 4096$ 为例,如果 $d_{ff} = 4 \times d_{model} = 16384$,数据流的维度变化为:
| 阶段 | 维度 | 说明 |
|---|---|---|
| 输入 | $(N, 4096)$ | 原始表示空间 |
| 升维 | $(N, 16384)$ | 高维展开空间 |
| 激活 | $(N, 16384)$ | 非线性变换 |
| 降维 | $(N, 4096)$ | 压缩回原始维度 |
2.2 为什么采用”先升维后降维”的瓶颈结构
这个”展开-压缩”的设计模式并非随意之举,背后有清晰的信息处理逻辑。
高维空间中线性可分性更好。 这是统计学习中的经典原理。想象你有一堆红球和蓝球混在一条线上(一维),很难找到一个点把它们完美分开。但如果你把这些球抛到一个三维空间中(升维),它们就更容易被一个平面分开。FFN 的第一层 $W_1$ 就是在做这件事——把信息投射到一个更高维的空间,让不同的特征模式在高维空间中更容易被区分和操作。
激活函数在高维空间中更有效。 非线性激活函数的作用是选择性地保留或抑制某些信息通道。在高维空间中,每个维度可以被看作一个独立的”特征检测器”,激活函数决定哪些检测器被激活、哪些被关闭。维度越高,可以同时检测的特征模式越多,模型的表达能力就越强。
降维是信息蒸馏。 第二层 $W_2$ 将高维空间中被激活函数筛选过的信息压缩回原始维度。这个过程迫使模型只保留最重要的信息——就像把一锅汤熬浓,去掉多余的水分,留下精华。
与 Autoencoder 的类比。 如果你熟悉自编码器(Autoencoder),FFN 的结构正好是反过来的。Autoencoder 先压缩再还原(编码-解码),用来学习数据的低维表示;FFN 先展开再压缩(展开-蒸馏),用来在高维空间中进行复杂的非线性变换后提取有用信息。
2.3 PyTorch 实现:标准 FFN
下面是一个带注释的标准 FFN 实现:
1 | import torch |
3. 激活函数详解
激活函数是 FFN 中至关重要的组件。没有激活函数,两层线性变换的复合仍然是线性变换($W_2 \cdot W_1 \cdot x$ 等价于一个 $(W_2 W_1) \cdot x$ 的单层线性变换),FFN 就失去了存在的意义。激活函数引入的非线性是模型表达能力的关键来源。
3.1 ReLU(Rectified Linear Unit)
数学公式
$$
\text{ReLU}(x) = \max(0, x)
$$
直觉解释
ReLU 是最朴素的”开关”函数——输入为正就原样通过,输入为负就直接关闭。你可以把它想象成一个单向阀门:水(信号)只能从正方向流过,负方向完全截断。
输出范围和特性
- 输出范围:$[0, +\infty)$
- 正半轴梯度恒为 1,不存在梯度衰减问题
- 负半轴梯度恒为 0,输出恒为 0
- 计算极其高效:只需一次比较操作
优点
- 计算简单,硬件友好——在 GPU 上只需一个 max 操作
- 正半轴梯度为 1,缓解了深层网络的梯度消失问题(相比 sigmoid/tanh)
- 引入稀疏性:负值被置零,激活后的向量通常有大量零元素,有利于特征的稀疏表示
缺点:神经元死亡问题(Dying ReLU)
这是 ReLU 最严重的缺陷,值得详细讨论。
当某个神经元的输入在训练过程中恰好落入负区间,ReLU 的输出为 0,梯度也为 0。零梯度意味着该神经元的权重在反向传播时得不到任何更新信号。而权重没有更新,下一轮前向传播时输入大概率仍然为负——于是这个神经元永久性地”死亡”了,不再对模型的计算做出任何贡献。
这个问题在学习率较大时尤为严重。一次大幅度的梯度更新可能把权重推到一个”坏”的区域,使得该神经元对所有训练样本的输出都为负,从此一蹶不振。在实际训练中,有时候会观察到超过 40% 的 ReLU 神经元处于”死亡”状态,这意味着模型有效容量大幅缩水。
LeakyReLU 通过给负半轴一个微小的斜率(如 0.01x)来缓解这个问题,但这只是打了个补丁,并没有从根本上解决激活函数在零点处不光滑、不可微的问题。
3.2 GELU(Gaussian Error Linear Unit)
数学公式
$$
\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]
$$
其中 $\Phi(x)$ 是标准正态分布的累积分布函数(CDF),$\text{erf}$ 是误差函数。
实际计算中常用近似公式:
$$
\text{GELU}(x) \approx 0.5x\left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715x^3)\right)\right]
$$
直觉解释:概率化门控
GELU 的核心思想是把激活看作一种”概率性门控”。对于输入 $x$,GELU 计算的是”以概率 $\Phi(x)$ 保留 $x$,以概率 $1-\Phi(x)$ 将其置零”的期望值。
为什么用正态分布的 CDF 作为保留概率?这基于一个直觉假设:如果我们认为神经网络中的输入信号近似服从正态分布(根据中心极限定理,多个随机变量之和趋向正态分布),那么一个输入值越大(越偏离均值向右),它越可能是”有意义的信号”而非噪声,就越应该被保留。$\Phi(x)$ 恰好提供了这样一个随输入值单调递增的保留概率。
与 ReLU 的硬切换(0 或全通过)相比,GELU 实现了一种”软门控”——接近零的输入不会被完全截断,而是按照一个与其大小成正比的概率被保留。
输出范围和特性
- 输出范围:约 $[-0.17, +\infty)$——负半轴有一个很小的负值区域
- 处处光滑可微,没有 ReLU 在零点处的不可微问题
- 负值区域输出非零但极小,梯度也非零,不存在神经元死亡问题
- 近似于 ReLU 但更平滑:大正值时 GELU(x) 趋近 x,大负值时趋近 0
优缺点
优点:
- 平滑可微,训练更稳定
- 概率化门控使信息保留更加精细
- 在 BERT、GPT-2 等模型中被验证效果优于 ReLU
缺点:
- 计算量比 ReLU 大(涉及 erf 函数或 tanh 近似)
- 但在现代 GPU 上,计算开销差异相对于 GEMM 而言几乎可忽略
3.3 Swish
数学公式
$$
\text{Swish}(x) = x \cdot \sigma(\beta x) = \frac{x}{1 + e^{-\beta x}}
$$
其中 $\sigma$ 是 sigmoid 函数,$\beta$ 是一个可学习或固定的参数。当 $\beta=1$ 时,通常简写为 $\text{SiLU}(x) = x \cdot \sigma(x)$。
直觉解释
Swish 可以看作是 ReLU 和线性函数之间的一种自适应插值。sigmoid 函数 $\sigma(x)$ 输出一个 0 到 1 之间的”门控值”,乘以 $x$ 本身,就实现了”用输入自身来控制信息通过量”的效果——这就是所谓的”自门控”(self-gating)。
当 $x$ 为大正值时,$\sigma(x)$ 趋近 1,$\text{Swish}(x) \approx x$,行为接近线性;当 $x$ 为大负值时,$\sigma(x)$ 趋近 0,$\text{Swish}(x) \approx 0$,行为接近截断。有趣的是,当 $x$ 在零附近的小负值时,Swish 的输出会略低于零——它有一个小的”凹坑”,最小值约为 $-0.278$(在 $x \approx -1.278$ 处取到)。
输出范围和特性
- 输出范围:约 $[-0.278, +\infty)$
- 处处光滑可微
- 非单调函数——在负半轴有一个小的下凹区域
- 当 $\beta \to \infty$ 时,Swish 退化为 ReLU;当 $\beta = 0$ 时,退化为线性函数 $x/2$
平滑性优势
Swish 的平滑性在训练中带来了实质性的好处。ReLU 在 $x=0$ 处梯度不连续(左导数为 0,右导数为 1),这意味着优化过程中的损失曲面(loss landscape)存在”棱角”。梯度下降算法在这些不光滑的点附近容易产生震荡。
Swish 的处处可微性质使损失曲面更加光滑,梯度信号在整个定义域上连续变化,优化过程更加稳定。Google Brain 的研究表明,Swish 在深层网络中的表现一致优于 ReLU,尤其是在网络深度超过 40 层时优势更为明显。
3.4 对比总结
| 特性 | ReLU | GELU | Swish/SiLU |
|---|---|---|---|
| 公式 | $\max(0,x)$ | $x \cdot \Phi(x)$ | $x \cdot \sigma(x)$ |
| 输出范围 | $[0,+\infty)$ | $\approx[-0.17,+\infty)$ | $\approx[-0.278,+\infty)$ |
| 零点可微 | 否 | 是 | 是 |
| 单调性 | 是 | 否(近似单调) | 否 |
| 负值输出 | 无 | 极小 | 小 |
| 神经元死亡 | 有 | 无 | 无 |
| 计算成本 | 最低 | 中 | 中 |
| 典型应用 | 原始 Transformer | BERT, GPT-2 | LLaMA (SwiGLU 中) |
4. SwiGLU:当前大模型的主流选择
4.1 GLU 门控机制的思想
在讲 SwiGLU 之前,先理解 GLU(Gated Linear Unit,门控线性单元)的思想。
标准 FFN 对输入施加的是”统一的”激活函数——所有维度经过相同的非线性变换。GLU 提出了一种不同的思路:让模型自己学习”哪些信息该通过、哪些该被抑制”。
GLU 的核心公式是:
$$
\text{GLU}(x) = (W_{up} \cdot x) \otimes \sigma(W_{gate} \cdot x)
$$
其中 $\otimes$ 表示逐元素相乘,$\sigma$ 是 sigmoid 函数。
这里引入了两个独立的线性投影:
- $W_{up} \cdot x$:生成”候选内容”——这些信息准备通过
- $\sigma(W_{gate} \cdot x)$:生成”门控信号”——每个维度的通过概率(0 到 1 之间)
两者逐元素相乘,门控信号大的维度信息被保留,门控信号小的维度信息被抑制。这比统一施加一个激活函数要灵活得多——模型可以针对不同的输入内容,学习不同的门控策略。
4.2 SwiGLU 的完整定义
SwiGLU 是 Noam Shazeer 在 2020 年的论文”GLU Variants Improve Transformer”中提出的 GLU 变体,将门控中的 sigmoid 替换为 Swish 函数:
$$
\text{SwiGLU}(x) = (\text{Swish}(W_{gate} \cdot x)) \otimes (W_{up} \cdot x)
$$
完整的 SwiGLU FFN 包含三个权重矩阵:
$$
\text{FFN}{SwiGLU}(x) = W{down} \cdot \left[\text{Swish}(W_{gate} \cdot x) \otimes (W_{up} \cdot x)\right]
$$
其中:
- $W_{gate}$:形状 $(d_{model}, d_{ff})$,门控投影,其输出经过 Swish 激活后产生门控信号
- $W_{up}$:形状 $(d_{model}, d_{ff})$,内容投影,生成候选通过的信息
- $W_{down}$:形状 $(d_{ff}, d_{model})$,降维投影,将结果压缩回原始维度
4.3 为什么门控机制有效
门控机制的有效性可以从几个角度理解。
更精细的信息筛选。 标准 FFN 中,激活函数对升维后的每个维度施加相同的变换规则——一个全局的非线性函数。SwiGLU 的门控机制则允许模型对每个维度施加不同的”通过/抑制”决策。这意味着模型可以学到更加条件化的计算——“当输入具有特征 A 时,保留这些维度;当输入具有特征 B 时,保留那些维度。”
更丰富的梯度信号。 由于门控路径和内容路径是两个独立的线性投影,反向传播时梯度可以沿着两条路径分别回传,参数的更新信号更加丰富。这可以理解为一种隐式的”梯度高速公路”——与残差连接的思想异曲同工。
实验验证。 Shazeer 的论文在相同参数量的条件下对比了多种 GLU 变体,SwiGLU 在多个下游任务上一致地取得了最佳表现。后来 LLaMA、Mistral、Qwen、DeepSeek 等主流大模型纷纷采用 SwiGLU,在工业实践中进一步验证了其有效性。
4.4 PyTorch 实现:SwiGLU FFN
1 | import torch |
5. 参数量计算与对比
FFN 的参数量在整个 Transformer Block 中占据主导地位。准确计算参数量是显存规划和并行策略设计的基础。
5.1 标准 FFN vs SwiGLU FFN 参数量对比
以 $d_{model} = 4096$ 为例,分别计算两种 FFN 结构的参数量。
标准 FFN($d_{ff} = 4 \times d_{model} = 16384$)
| 参数矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_{up}$ | (4096, 16384) | 67,108,864 |
| $W_{down}$ | (16384, 4096) | 67,108,864 |
| 合计 | 134,217,728 (134M) |
SwiGLU FFN($d_{ff} = \frac{8}{3} \times d_{model} \approx 10922$,实际取 11008)
| 参数矩阵 | 形状 | 参数量 |
|---|---|---|
| $W_{gate}$ | (4096, 11008) | 45,088,768 |
| $W_{up}$ | (4096, 11008) | 45,088,768 |
| $W_{down}$ | (11008, 4096) | 45,088,768 |
| 合计 | 135,266,304 (135M) |
5.2 对比分析
| 对比维度 | 标准 FFN | SwiGLU FFN |
|---|---|---|
| 权重矩阵数量 | 2 | 3 |
| 中间维度 | $4 \times d_{model}$ | $\frac{8}{3} \times d_{model}$ |
| 总参数量 | $8 \times d_{model}^2$ | $3 \times d_{model} \times \frac{8}{3} d_{model} = 8 \times d_{model}^2$ |
| 以 $d_{model}=4096$ 计算 | 134M | 135M |
| GEMM 次数(前向) | 2 | 3 |
| 激活函数 | 全局施加 | 门控选择性施加 |
关键发现:两种结构的总参数量几乎相同($8 \times d_{model}^2$),但 SwiGLU 通过门控机制实现了更好的效果。代价是前向传播需要 3 次 GEMM 而非 2 次,计算量略有增加。
5.3 与 Attention 模块的参数量对比
| 模块 | 参数量 | 占 Block 比例 |
|---|---|---|
| Self-Attention ($W_Q, W_K, W_V, W_O$) | $4 \times d_{model}^2 = 67M$ | ~33% |
| FFN(标准或 SwiGLU) | $8 \times d_{model}^2 = 135M$ | ~67% |
| LayerNorm $\times 2$ | $4 \times d_{model} = 16K$ | <0.01% |
| 单 Block 合计 | ~201M |
FFN 的参数量约为 Attention 的 2 倍,占单个 Block 参数总量的约 2/3。这意味着任何针对 FFN 的优化(并行切分、量化、专家化)都将对整体效率产生显著影响。
6. FFN 中间维度的选择:为什么是 4x 和 8/3x
6.1 标准 FFN 为什么选择 4 倍扩展
$d_{ff} = 4 \times d_{model}$ 的选择来自原始 Transformer 论文”Attention Is All You Need”。Vaswani 等人设定 $d_{model}=512, d_{ff}=2048$,形成了 4 倍扩展比。这个比例后来成为了事实上的标准。
为什么是 4 倍而不是 2 倍或 8 倍?这在很大程度上是经验性的选择。从直觉上说:
- 太小(如 2x):高维空间不够大,非线性变换的表达能力受限,模型效果下降
- 太大(如 8x):参数量急剧膨胀,但效果提升边际递减,参数效率降低
- 4x 是一个平衡点:在模型效果和参数效率之间取得较好的折中
后续研究(如 Kaplan 等人关于 Scaling Laws 的工作)也表明,在固定总参数量的前提下,$d_{ff} / d_{model}$ 的比值在 4 附近是近似最优的。这个比值过大或过小都会导致模型在同等参数量下表现变差。
6.2 SwiGLU 为什么用 8/3 倍扩展
SwiGLU 将标准 FFN 的 2 个矩阵变成了 3 个矩阵。如果继续使用 $d_{ff} = 4 \times d_{model}$,总参数量会变成:
$$
3 \times d_{model} \times 4d_{model} = 12 \times d_{model}^2
$$
这比标准 FFN 的 $8 \times d_{model}^2$ 多了 50% 的参数。为了在引入门控机制的同时保持总参数量不变,需要调整中间维度:
$$
3 \times d_{model} \times d_{ff} = 8 \times d_{model}^2
$$
$$
d_{ff} = \frac{8}{3} \times d_{model} \approx 2.667 \times d_{model}
$$
以 $d_{model} = 4096$ 计算:$d_{ff} = \frac{8}{3} \times 4096 = 10922.67$。
在工程实现中,$d_{ff}$ 通常会被取整到特定数值的倍数(如 128 或 256 的倍数),以确保矩阵维度能够被 GPU 的 Tensor Core 高效处理。Tensor Core 在执行矩阵乘法时要求维度满足特定对齐要求(例如 FP16 下通常是 8 的倍数,实践中对齐到 128 或 256 可以获得更好的利用率)。
以 LLaMA-2-7B 为例:$10922.67$ 取整到 $11008$($= 43 \times 256$),这就是我们看到的 LLaMA 配置中 $d_{ff} = 11008$ 的由来。
7. 张量并行切分 FFN
当模型规模大到单张 GPU 放不下时,需要将模型参数切分到多张 GPU 上——这就是张量并行(Tensor Parallelism, TP)。FFN 占据 Block 参数的 2/3,如何高效切分 FFN 直接决定了张量并行的整体效率。
7.1 Megatron-LM 的列切分 + 行切分策略
Megatron-LM 提出了一种精巧的切分方案,核心思想是:**$W_{up}$(和 $W_{gate}$)按列切分,$W_{down}$ 按行切分**,使得中间结果不需要通信,只在最终输出时做一次 AllReduce。
以 2 张 GPU 为例,标准 FFN 的切分如下:
第一步:$W_{up}$ 列切分
将 $W_{up}$ 的列(输出维度)均匀分配给两张 GPU:
1 | GPU 0: W_up[:, :d_ff/2] 形状 (d_model, d_ff/2) |
每张 GPU 各自计算自己那部分的升维结果:
1 | GPU 0: h_0 = activation(x @ W_up_0) 形状 (N, d_ff/2) |
注意:这里 $x$ 需要在两张 GPU 上各有一份完整拷贝(通过前一步的 AllReduce 或广播获得)。列切分的关键优势在于:激活函数可以在切分后独立施加——因为激活函数是逐元素操作,不依赖其他维度的值,所以对完整向量施加激活等价于对各部分分别施加激活。
第二步:$W_{down}$ 行切分
将 $W_{down}$ 的行(输入维度)均匀分配:
1 | GPU 0: W_down[:d_ff/2, :] 形状 (d_ff/2, d_model) |
每张 GPU 用自己的中间结果乘以自己的 $W_{down}$ 部分:
1 | GPU 0: out_0 = h_0 @ W_down_0 形状 (N, d_model) |
第三步:AllReduce 求和
最终输出等于两张 GPU 的部分结果之和:
1 | output = AllReduce(out_0, out_1) = out_0 + out_1 形状 (N, d_model) |
这是因为:
$$
W_{down} \cdot h = \begin{bmatrix} W_{down,0} \ W_{down,1} \end{bmatrix}^T \cdot \begin{bmatrix} h_0 \ h_1 \end{bmatrix} = W_{down,0}^T \cdot h_0 + W_{down,1}^T \cdot h_1
$$
7.2 SwiGLU 的切分
对于 SwiGLU FFN,$W_{gate}$ 和 $W_{up}$ 都按列切分,$W_{down}$ 按行切分。逻辑完全一致:
1 | GPU 0: gate_0 = Swish(x @ W_gate_0) # (N, d_ff/2) |
整个 FFN 在前向传播中只需要一次 AllReduce。而每张 GPU 上的三次 GEMM(gate/up/down)和逐元素操作都是本地计算,不需要通信。这使得张量并行在 FFN 上的通信开销非常低——每个 Block 的 FFN 只贡献一次 AllReduce(Attention 也贡献一次,所以每个 Block 共两次 AllReduce)。
7.3 工程考量
在实际部署中,$d_{ff}$ 需要能被 TP 的 GPU 数量整除。例如,$d_{ff} = 11008$ 在 TP=2 时切分为每卡 5504,TP=4 时每卡 2752,TP=8 时每卡 1376——都能整除。这也是 $d_{ff}$ 选取时会优先考虑 2 的幂次或者 128/256 倍数的原因之一。
8. MoE:将 FFN 拆分为多个专家
8.1 核心思想
混合专家模型(Mixture of Experts, MoE)是一种扩大模型容量而不等比例增加计算量的技术,其核心改造对象正是 FFN 层。
基本思路非常直观:与其使用一个巨大的 FFN,不如将其拆分为多个较小的”专家”(Expert),每个专家都是一个独立的 FFN。对于每个输入 token,只激活其中少数几个专家进行计算,其余专家保持”休眠”。
例如,一个 MoE 层可能包含 64 个专家 FFN,但每个 token 只选择其中 2 个专家(Top-2)进行前向计算。这意味着:
- 模型总参数量 = 64 个专家的参数量之和(很大)
- 每个 token 的计算量 = 2 个专家的计算量(很小)
这就实现了”参数量大但计算量小”的目标——模型有足够的容量存储知识,但推理和训练时的计算成本可控。
8.2 路由机制(Router)
MoE 的关键组件是路由器(Router / Gate),它决定每个 token 应该被发送到哪些专家。
最常见的路由方式是 Top-K 路由:
$$
g(x) = \text{softmax}(W_r \cdot x)
$$
$$
\text{TopK}(g(x)) \to \text{选出概率最大的 K 个专家}
$$
其中 $W_r$ 是路由器的权重矩阵,形状为 $(d_{model}, n_{experts})$。路由器将每个 token 的表示向量映射到一个 $n_{experts}$ 维的概率分布,概率最大的 K 个专家被选中。
最终输出是被选中专家输出的加权和:
$$
\text{MoE}(x) = \sum_{i \in \text{TopK}} g_i(x) \cdot \text{Expert}_i(x)
$$
权重 $g_i(x)$ 经过 re-normalize(重归一化为和为 1),确保输出的尺度与单个 FFN 一致。
8.3 负载均衡
MoE 训练中的一个棘手问题是负载不均衡——路由器可能倾向于把大部分 token 都发送给少数几个”明星专家”,导致其他专家得不到训练、逐渐退化,进而加剧不均衡,形成恶性循环。
常用的缓解策略包括:
辅助损失(Auxiliary Loss / Load Balancing Loss)
在训练目标中加入一项额外的损失函数,惩罚负载不均匀的情况。常见形式是:
$$
L_{aux} = \alpha \cdot n_{experts} \cdot \sum_{i=1}^{n_{experts}} f_i \cdot p_i
$$
其中 $f_i$ 是专家 $i$ 实际接收的 token 比例,$p_i$ 是路由器分配给专家 $i$ 的平均概率,$\alpha$ 是平衡系数。当所有专家的负载均匀时,该损失取最小值。
容量因子(Capacity Factor)
为每个专家设定一个最大接收 token 数。超出容量的 token 会被丢弃(送入残差路径)或重新路由到其他专家。这从硬性约束的角度防止单个专家过载。
Expert Choice 路由
与传统的 token-choose-expert(由 token 选择专家)不同,Expert Choice 让每个专家主动选择要处理的 top-k 个 token。这天然保证每个专家处理相同数量的 token,从根本上解决负载不均衡的问题。
8.4 AI Infra 关联:Expert Parallelism
MoE 引入了一种新的并行维度——专家并行(Expert Parallelism, EP)。由于每个专家是独立的 FFN,可以将不同的专家放置在不同的 GPU 上。
这带来了独特的通信模式:传统的张量并行使用 AllReduce(所有卡都参与),而 MoE 的专家并行使用 All-to-All 通信——每个 GPU 需要把分配给其他 GPU 上专家的 token 发送过去,同时接收分配给本地专家的 token。这种通信模式对网络拓扑和通信库的要求与 AllReduce 截然不同,是 MoE 系统设计中的核心挑战。
Mixtral 8x7B、DeepSeek-V2 等模型的成功部署,背后都需要精心设计的 EP 调度策略和通信优化。
9. CUDA Kernel 融合在 FFN 中的应用
9.1 为什么需要 Kernel 融合
GPU 执行计算的基本单位是 kernel——每次从 CPU 调度一个 kernel 到 GPU 上执行。每次 kernel 启动(launch)都有固定的调度开销(通常几微秒),而且每个独立的 kernel 都需要从 HBM(高带宽显存)读取输入、写回输出。
FFN 的前向传播包含多个步骤:GEMM、激活函数、逐元素乘法等。如果每个步骤都作为独立的 kernel 执行,中间结果就需要反复在 HBM 和寄存器/SRAM 之间搬运。对于 SwiGLU FFN,一次前向传播的逐元素操作包括:
- Swish 激活:$\text{gate} = x_gate \cdot \sigma(x_gate)$(涉及 sigmoid 和逐元素乘法)
- 门控乘法:$\text{hidden} = \text{gate} \otimes \text{up}$
如果分开执行,gate 向量需要写入 HBM 后再读回来做乘法。融合后,这些操作可以在一个 kernel 内完成——gate 的中间结果驻留在寄存器中,直接与 up 相乘,完全不经过 HBM。
9.2 FFN 中的典型融合场景
场景一:Swish 激活 + 门控乘法融合
将 Swish(gate) * up 融合为一个 kernel。两个 $(N, d_{ff})$ 的输入产生一个 $(N, d_{ff})$ 的输出,HBM 读写量从 4 次降为 3 次(读 2 次 + 写 1 次,省去了 Swish 中间结果的写入和读回)。
场景二:GEMM + 激活融合(Epilogue Fusion)
在 GEMM kernel 的”收尾阶段”(epilogue)中嵌入激活函数。以 $W_{gate}$ 的 GEMM 为例,标准流程是:
- kernel A: $x_gate = x \cdot W_{gate}$,写入 HBM
- kernel B: $\text{gate} = \text{Swish}(x_gate)$,从 HBM 读出 $x_gate$ 并写回 gate
融合后,在 GEMM kernel 的最后一步(将结果写回 HBM 之前),直接在寄存器中对每个输出元素施加 Swish,然后写入 HBM。这样只需一个 kernel 和一次 HBM 写入。
cuBLAS 和 CUTLASS 等 GEMM 库都提供了 epilogue fusion 的接口,支持在 GEMM 输出后拼接自定义的逐元素操作。
场景三:$W_{gate}$ 和 $W_{up}$ 的 GEMM 合并
$W_{gate}$ 和 $W_{up}$ 的输入都是同一个 $x$,可以将两个矩阵沿列方向拼接为一个 $(d_{model}, 2 \times d_{ff})$ 的大矩阵,执行一次大 GEMM,然后将输出切分为 gate 和 up 两部分。这将两次 GEMM 合并为一次,减少 kernel launch 开销,同时大 GEMM 的 GPU 利用率通常优于两个小 GEMM。
1 | # 分开执行(2 次 GEMM launch): |
场景四:偏置加法 + 残差加法 + LayerNorm 融合
FFN 的输出通常要经过残差加法和后续的 LayerNorm。将这些逐元素操作融合为一个 kernel,可以将 FFN 输出、残差输入和 LayerNorm 参数一次性读入,在片上完成加法和归一化后写回,避免多次 HBM 往返。
9.3 性能影响
在实际大模型推理中,FFN 的 kernel 融合带来的加速效果因场景而异:
- Decode 阶段(每次只处理 1 个 token):GEMM 退化为矩阵-向量乘法,计算量小,kernel launch 和 HBM 带宽是主要瓶颈,融合收益显著(可达 20-40% 的加速)
- Prefill 阶段(处理整段 prompt):GEMM 的计算量占主导,逐元素操作的 HBM 开销占比较小,融合收益相对有限(通常 5-15%)
🎯 自我检验清单
完成本文学习后,检验自己是否真正理解了 FFN 模块:
- 能说清 FFN 在 Transformer Block 中的角色——为什么 Attention 之后还需要 FFN,FFN 负责什么
- 能在白板上画出标准 FFN 和 SwiGLU FFN 的结构图,标注每个矩阵的维度和数据流方向
- 能手算标准 FFN 和 SwiGLU FFN 的参数量,并解释为什么两者参数量几乎相同
- 能解释为什么 FFN 的中间维度选择 4x(标准)和 8/3x(SwiGLU),以及 LLaMA 中 11008 这个数值的由来
- 能分别说清 ReLU、GELU、Swish 三种激活函数的数学定义、直觉含义和各自优缺点
- 能解释 ReLU 的神经元死亡问题为什么会发生,以及 GELU 和 Swish 如何避免这个问题
- 能解释 SwiGLU 的门控机制——$W_{gate}$ 和 $W_{up}$ 各自的作用,为什么门控比统一激活更有效
- 能描述 Megatron-LM 对 FFN 做张量并行的切分方式——列切分 $W_{up}$/$W_{gate}$,行切分 $W_{down}$,为什么这样切分只需要一次 AllReduce
- 能概述 MoE 的基本思想——多个专家 FFN、路由机制、负载均衡,以及 Expert Parallelism 的通信模式
- 能举出至少两个 FFN 中 CUDA kernel 融合的具体场景
📚 参考资料
论文
- Attention Is All You Need (Vaswani et al., 2017): https://arxiv.org/abs/1706.03762 – Transformer 原始论文,FFN 结构的来源
- GLU Variants Improve Transformer (Shazeer, 2020): https://arxiv.org/abs/2002.05202 – SwiGLU 等门控激活函数的提出与实验对比
- Gaussian Error Linear Units (GELUs) (Hendrycks & Gimpel, 2016): https://arxiv.org/abs/1606.08415 – GELU 激活函数
- Searching for Activation Functions (Ramachandran et al., 2017): https://arxiv.org/abs/1710.05941 – Swish 激活函数的发现
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (Shoeybi et al., 2019): https://arxiv.org/abs/1909.08053 – 张量并行的切分策略
- Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity (Fedus et al., 2022): https://arxiv.org/abs/2101.03961 – MoE 架构的简化与大规模应用
- LLaMA: Open and Efficient Foundation Language Models (Touvron et al., 2023): https://arxiv.org/abs/2302.13971 – SwiGLU 在 LLaMA 中的应用
- Mixtral of Experts (Jiang et al., 2024): https://arxiv.org/abs/2401.04088 – 稀疏 MoE 架构
教程与博客
- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/ – 图文并茂的 Transformer 入门
- Andrej Karpathy: Let’s build GPT from scratch: https://www.youtube.com/watch?v=kCc8FmEb1nY – 从零手写 GPT
- The Annotated Transformer (Harvard NLP): https://nlp.seas.harvard.edu/annotated-transformer/ – 论文逐行对应 PyTorch 实现