2.4 Transformer前馈网络FFN深入理解

前馈网络(FFN)是 Transformer 中参数量最大的模块,占据单个 Block 约 2/3 的参数。它负责对每个 token 进行独立的非线性变换,是模型”记忆知识”和”深度推理”的核心载体。本文从 FFN 的结构设计出发,深入剖析激活函数演进、参数量计算、SwiGLU 门控机制,并延伸到张量并行切分、MoE 专家并行和 CUDA kernel 融合等 AI Infra 工程实践。

📑 目录


1. FFN 在 Transformer 中的角色定位

要理解 FFN 的作用,先回顾一下 Transformer Block 的整体结构。每个 Block 包含两个核心子模块:Self-Attention 和 FFN,两者通过残差连接和 LayerNorm 串联起来。

Self-Attention 的职责是信息交互——让序列中的每个 token 能够”看见”其他 token 并汇聚相关信息。打个比方,如果把一个 Transformer Block 看作一家工厂的生产流水线,Self-Attention 就是流水线上的”原料分拣站”:它根据订单需求(Query)从仓库中各个货架(Key)上挑选相关原料(Value),把需要的材料汇集到一起。

而 FFN 则是紧随其后的”加工车间”。原料分拣完毕后,需要在加工车间里对每份原料进行独立的深加工——切割、熔炼、塑形——把粗加工的信息变成精细的、有结构的产出。关键在于:加工车间里的每条生产线是独立运作的,每个 token 各走各的生产线,彼此之间不再交互。

用更技术性的语言来说:

  • Self-Attention 是 token 间的运算——输出依赖于序列中所有 token 的组合关系
  • FFN 是 token 内的运算——对每个 token 的表示向量独立施加相同的非线性变换

这种”先交互、后加工”的两步设计有深层次的意义。仅靠线性的信息汇聚(Attention 本质上是加权求和,是一种线性运算),模型的表达能力会受到严重限制。FFN 引入的非线性变换赋予了模型逼近任意复杂函数的能力——根据万能近似定理(Universal Approximation Theorem),一个带有非线性激活函数的两层前馈网络可以逼近任意连续函数。

近年来的研究还发现,FFN 在大模型中扮演着”知识库”的角色。大量的事实性知识(”巴黎是法国的首都”、”水的分子式是 H2O”)被编码在 FFN 的参数中,而 Attention 层更多负责组织和提取这些知识。这也解释了为什么 FFN 占据了模型大部分的参数量——它需要足够大的”存储容量”来记忆海量知识。


2. FFN 的”展开-压缩”结构

2.1 基本结构

标准 FFN 的数学表达式非常简洁:

$$
\text{FFN}(x) = W_2 \cdot \sigma(W_1 \cdot x + b_1) + b_2
$$

其中:

  • $x$ 是输入向量,维度为 $d_{model}$
  • $W_1$ 的形状为 $(d_{model}, d_{ff})$,将维度从 $d_{model}$ 升高到 $d_{ff}$
  • $\sigma$ 是非线性激活函数
  • $W_2$ 的形状为 $(d_{ff}, d_{model})$,将维度从 $d_{ff}$ 压缩回 $d_{model}$
  • $b_1, b_2$ 是偏置项(现代大模型通常省略偏置)

以 $d_{model} = 4096$ 为例,如果 $d_{ff} = 4 \times d_{model} = 16384$,数据流的维度变化为:

阶段 维度 说明
输入 $(N, 4096)$ 原始表示空间
升维 $(N, 16384)$ 高维展开空间
激活 $(N, 16384)$ 非线性变换
降维 $(N, 4096)$ 压缩回原始维度

2.2 为什么采用”先升维后降维”的瓶颈结构

这个”展开-压缩”的设计模式并非随意之举,背后有清晰的信息处理逻辑。

高维空间中线性可分性更好。 这是统计学习中的经典原理。想象你有一堆红球和蓝球混在一条线上(一维),很难找到一个点把它们完美分开。但如果你把这些球抛到一个三维空间中(升维),它们就更容易被一个平面分开。FFN 的第一层 $W_1$ 就是在做这件事——把信息投射到一个更高维的空间,让不同的特征模式在高维空间中更容易被区分和操作。

激活函数在高维空间中更有效。 非线性激活函数的作用是选择性地保留或抑制某些信息通道。在高维空间中,每个维度可以被看作一个独立的”特征检测器”,激活函数决定哪些检测器被激活、哪些被关闭。维度越高,可以同时检测的特征模式越多,模型的表达能力就越强。

降维是信息蒸馏。 第二层 $W_2$ 将高维空间中被激活函数筛选过的信息压缩回原始维度。这个过程迫使模型只保留最重要的信息——就像把一锅汤熬浓,去掉多余的水分,留下精华。

与 Autoencoder 的类比。 如果你熟悉自编码器(Autoencoder),FFN 的结构正好是反过来的。Autoencoder 先压缩再还原(编码-解码),用来学习数据的低维表示;FFN 先展开再压缩(展开-蒸馏),用来在高维空间中进行复杂的非线性变换后提取有用信息。

2.3 PyTorch 实现:标准 FFN

下面是一个带注释的标准 FFN 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.nn.functional as F

class StandardFFN(nn.Module):
"""标准前馈网络:升维 -> 激活 -> 降维"""

def __init__(self, d_model: int, d_ff: int, activation: str = "relu"):
"""
Args:
d_model: 模型隐藏维度,如 4096
d_ff: FFN 中间维度,通常为 4 * d_model,如 16384
activation: 激活函数类型,可选 "relu" 或 "gelu"
"""
super().__init__()

# 升维投影:(d_model, d_ff),将表示从原始空间映射到高维空间
self.w_up = nn.Linear(d_model, d_ff, bias=False)

# 降维投影:(d_ff, d_model),将高维空间中处理后的信息压缩回原始维度
self.w_down = nn.Linear(d_ff, d_model, bias=False)

# 选择激活函数
if activation == "relu":
self.activation = nn.ReLU()
elif activation == "gelu":
self.activation = nn.GELU()
else:
raise ValueError(f"不支持的激活函数: {activation}")

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
前向传播
Args:
x: 输入张量,形状 (batch_size, seq_len, d_model)
Returns:
输出张量,形状 (batch_size, seq_len, d_model)
"""
# Step 1: 升维 (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
h = self.w_up(x)

# Step 2: 非线性激活,在高维空间中进行特征筛选
h = self.activation(h)

# Step 3: 降维 (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
output = self.w_down(h)

return output

# 使用示例
d_model = 4096
d_ff = 4 * d_model # 16384

ffn = StandardFFN(d_model, d_ff, activation="gelu")
x = torch.randn(2, 128, d_model) # batch=2, seq_len=128
out = ffn(x) # 输出形状: (2, 128, 4096)
print(f"输入形状: {x.shape}, 输出形状: {out.shape}")
# 参数量: 4096*16384 + 16384*4096 = 134,217,728 (约 134M)
print(f"参数量: {sum(p.numel() for p in ffn.parameters()):,}")

3. 激活函数详解

激活函数是 FFN 中至关重要的组件。没有激活函数,两层线性变换的复合仍然是线性变换($W_2 \cdot W_1 \cdot x$ 等价于一个 $(W_2 W_1) \cdot x$ 的单层线性变换),FFN 就失去了存在的意义。激活函数引入的非线性是模型表达能力的关键来源。

3.1 ReLU(Rectified Linear Unit)

数学公式

$$
\text{ReLU}(x) = \max(0, x)
$$

直觉解释

ReLU 是最朴素的”开关”函数——输入为正就原样通过,输入为负就直接关闭。你可以把它想象成一个单向阀门:水(信号)只能从正方向流过,负方向完全截断。

输出范围和特性

  • 输出范围:$[0, +\infty)$
  • 正半轴梯度恒为 1,不存在梯度衰减问题
  • 负半轴梯度恒为 0,输出恒为 0
  • 计算极其高效:只需一次比较操作

优点

  • 计算简单,硬件友好——在 GPU 上只需一个 max 操作
  • 正半轴梯度为 1,缓解了深层网络的梯度消失问题(相比 sigmoid/tanh)
  • 引入稀疏性:负值被置零,激活后的向量通常有大量零元素,有利于特征的稀疏表示

缺点:神经元死亡问题(Dying ReLU)

这是 ReLU 最严重的缺陷,值得详细讨论。

当某个神经元的输入在训练过程中恰好落入负区间,ReLU 的输出为 0,梯度也为 0。零梯度意味着该神经元的权重在反向传播时得不到任何更新信号。而权重没有更新,下一轮前向传播时输入大概率仍然为负——于是这个神经元永久性地”死亡”了,不再对模型的计算做出任何贡献。

这个问题在学习率较大时尤为严重。一次大幅度的梯度更新可能把权重推到一个”坏”的区域,使得该神经元对所有训练样本的输出都为负,从此一蹶不振。在实际训练中,有时候会观察到超过 40% 的 ReLU 神经元处于”死亡”状态,这意味着模型有效容量大幅缩水。

LeakyReLU 通过给负半轴一个微小的斜率(如 0.01x)来缓解这个问题,但这只是打了个补丁,并没有从根本上解决激活函数在零点处不光滑、不可微的问题。

3.2 GELU(Gaussian Error Linear Unit)

数学公式

$$
\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]
$$

其中 $\Phi(x)$ 是标准正态分布的累积分布函数(CDF),$\text{erf}$ 是误差函数。

实际计算中常用近似公式:

$$
\text{GELU}(x) \approx 0.5x\left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715x^3)\right)\right]
$$

直觉解释:概率化门控

GELU 的核心思想是把激活看作一种”概率性门控”。对于输入 $x$,GELU 计算的是”以概率 $\Phi(x)$ 保留 $x$,以概率 $1-\Phi(x)$ 将其置零”的期望值。

为什么用正态分布的 CDF 作为保留概率?这基于一个直觉假设:如果我们认为神经网络中的输入信号近似服从正态分布(根据中心极限定理,多个随机变量之和趋向正态分布),那么一个输入值越大(越偏离均值向右),它越可能是”有意义的信号”而非噪声,就越应该被保留。$\Phi(x)$ 恰好提供了这样一个随输入值单调递增的保留概率。

与 ReLU 的硬切换(0 或全通过)相比,GELU 实现了一种”软门控”——接近零的输入不会被完全截断,而是按照一个与其大小成正比的概率被保留。

输出范围和特性

  • 输出范围:约 $[-0.17, +\infty)$——负半轴有一个很小的负值区域
  • 处处光滑可微,没有 ReLU 在零点处的不可微问题
  • 负值区域输出非零但极小,梯度也非零,不存在神经元死亡问题
  • 近似于 ReLU 但更平滑:大正值时 GELU(x) 趋近 x,大负值时趋近 0

优缺点

优点:

  • 平滑可微,训练更稳定
  • 概率化门控使信息保留更加精细
  • 在 BERT、GPT-2 等模型中被验证效果优于 ReLU

缺点:

  • 计算量比 ReLU 大(涉及 erf 函数或 tanh 近似)
  • 但在现代 GPU 上,计算开销差异相对于 GEMM 而言几乎可忽略

3.3 Swish

数学公式

$$
\text{Swish}(x) = x \cdot \sigma(\beta x) = \frac{x}{1 + e^{-\beta x}}
$$

其中 $\sigma$ 是 sigmoid 函数,$\beta$ 是一个可学习或固定的参数。当 $\beta=1$ 时,通常简写为 $\text{SiLU}(x) = x \cdot \sigma(x)$。

直觉解释

Swish 可以看作是 ReLU 和线性函数之间的一种自适应插值。sigmoid 函数 $\sigma(x)$ 输出一个 0 到 1 之间的”门控值”,乘以 $x$ 本身,就实现了”用输入自身来控制信息通过量”的效果——这就是所谓的”自门控”(self-gating)。

当 $x$ 为大正值时,$\sigma(x)$ 趋近 1,$\text{Swish}(x) \approx x$,行为接近线性;当 $x$ 为大负值时,$\sigma(x)$ 趋近 0,$\text{Swish}(x) \approx 0$,行为接近截断。有趣的是,当 $x$ 在零附近的小负值时,Swish 的输出会略低于零——它有一个小的”凹坑”,最小值约为 $-0.278$(在 $x \approx -1.278$ 处取到)。

输出范围和特性

  • 输出范围:约 $[-0.278, +\infty)$
  • 处处光滑可微
  • 非单调函数——在负半轴有一个小的下凹区域
  • 当 $\beta \to \infty$ 时,Swish 退化为 ReLU;当 $\beta = 0$ 时,退化为线性函数 $x/2$

平滑性优势

Swish 的平滑性在训练中带来了实质性的好处。ReLU 在 $x=0$ 处梯度不连续(左导数为 0,右导数为 1),这意味着优化过程中的损失曲面(loss landscape)存在”棱角”。梯度下降算法在这些不光滑的点附近容易产生震荡。

Swish 的处处可微性质使损失曲面更加光滑,梯度信号在整个定义域上连续变化,优化过程更加稳定。Google Brain 的研究表明,Swish 在深层网络中的表现一致优于 ReLU,尤其是在网络深度超过 40 层时优势更为明显。

3.4 对比总结

特性 ReLU GELU Swish/SiLU
公式 $\max(0,x)$ $x \cdot \Phi(x)$ $x \cdot \sigma(x)$
输出范围 $[0,+\infty)$ $\approx[-0.17,+\infty)$ $\approx[-0.278,+\infty)$
零点可微
单调性 否(近似单调)
负值输出 极小
神经元死亡
计算成本 最低
典型应用 原始 Transformer BERT, GPT-2 LLaMA (SwiGLU 中)

4. SwiGLU:当前大模型的主流选择

4.1 GLU 门控机制的思想

在讲 SwiGLU 之前,先理解 GLU(Gated Linear Unit,门控线性单元)的思想。

标准 FFN 对输入施加的是”统一的”激活函数——所有维度经过相同的非线性变换。GLU 提出了一种不同的思路:让模型自己学习”哪些信息该通过、哪些该被抑制”。

GLU 的核心公式是:

$$
\text{GLU}(x) = (W_{up} \cdot x) \otimes \sigma(W_{gate} \cdot x)
$$

其中 $\otimes$ 表示逐元素相乘,$\sigma$ 是 sigmoid 函数。

这里引入了两个独立的线性投影:

  • $W_{up} \cdot x$:生成”候选内容”——这些信息准备通过
  • $\sigma(W_{gate} \cdot x)$:生成”门控信号”——每个维度的通过概率(0 到 1 之间)

两者逐元素相乘,门控信号大的维度信息被保留,门控信号小的维度信息被抑制。这比统一施加一个激活函数要灵活得多——模型可以针对不同的输入内容,学习不同的门控策略。

4.2 SwiGLU 的完整定义

SwiGLU 是 Noam Shazeer 在 2020 年的论文”GLU Variants Improve Transformer”中提出的 GLU 变体,将门控中的 sigmoid 替换为 Swish 函数:

$$
\text{SwiGLU}(x) = (\text{Swish}(W_{gate} \cdot x)) \otimes (W_{up} \cdot x)
$$

完整的 SwiGLU FFN 包含三个权重矩阵:

$$
\text{FFN}{SwiGLU}(x) = W{down} \cdot \left[\text{Swish}(W_{gate} \cdot x) \otimes (W_{up} \cdot x)\right]
$$

其中:

  • $W_{gate}$:形状 $(d_{model}, d_{ff})$,门控投影,其输出经过 Swish 激活后产生门控信号
  • $W_{up}$:形状 $(d_{model}, d_{ff})$,内容投影,生成候选通过的信息
  • $W_{down}$:形状 $(d_{ff}, d_{model})$,降维投影,将结果压缩回原始维度

4.3 为什么门控机制有效

门控机制的有效性可以从几个角度理解。

更精细的信息筛选。 标准 FFN 中,激活函数对升维后的每个维度施加相同的变换规则——一个全局的非线性函数。SwiGLU 的门控机制则允许模型对每个维度施加不同的”通过/抑制”决策。这意味着模型可以学到更加条件化的计算——“当输入具有特征 A 时,保留这些维度;当输入具有特征 B 时,保留那些维度。”

更丰富的梯度信号。 由于门控路径和内容路径是两个独立的线性投影,反向传播时梯度可以沿着两条路径分别回传,参数的更新信号更加丰富。这可以理解为一种隐式的”梯度高速公路”——与残差连接的思想异曲同工。

实验验证。 Shazeer 的论文在相同参数量的条件下对比了多种 GLU 变体,SwiGLU 在多个下游任务上一致地取得了最佳表现。后来 LLaMA、Mistral、Qwen、DeepSeek 等主流大模型纷纷采用 SwiGLU,在工业实践中进一步验证了其有效性。

4.4 PyTorch 实现:SwiGLU FFN

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
"""SwiGLU 前馈网络:门控投影 + 内容投影 -> 逐元素相乘 -> 降维"""

def __init__(self, d_model: int, d_ff: int):
"""
Args:
d_model: 模型隐藏维度,如 4096
d_ff: FFN 中间维度,SwiGLU 通常使用 (8/3) * d_model,如 11008
"""
super().__init__()

# 门控投影:(d_model, d_ff)
# 输出经过 Swish 激活后,作为门控信号控制信息通过量
self.w_gate = nn.Linear(d_model, d_ff, bias=False)

# 内容投影:(d_model, d_ff)
# 生成候选通过的信息内容
self.w_up = nn.Linear(d_model, d_ff, bias=False)

# 降维投影:(d_ff, d_model)
# 将门控筛选后的高维信息压缩回原始维度
self.w_down = nn.Linear(d_ff, d_model, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
前向传播
Args:
x: 输入张量,形状 (batch_size, seq_len, d_model)
Returns:
输出张量,形状 (batch_size, seq_len, d_model)
"""
# Step 1: 门控路径 — 投影到高维空间后经 Swish 激活
# (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
gate = F.silu(self.w_gate(x)) # silu 即 Swish(x) = x * sigmoid(x)

# Step 2: 内容路径 — 投影到高维空间(不经过激活函数)
# (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
up = self.w_up(x)

# Step 3: 门控乘法 — 门控信号与内容逐元素相乘
# 门控值大的维度保留信息,门控值小的维度抑制信息
# (batch_size, seq_len, d_ff)
hidden = gate * up

# Step 4: 降维 — 压缩回原始维度
# (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
output = self.w_down(hidden)

return output

# 使用示例:LLaMA-2-7B 配置
d_model = 4096
d_ff = 11008 # (8/3) * 4096,取整到 256 的倍数

swiglu_ffn = SwiGLUFFN(d_model, d_ff)
x = torch.randn(2, 128, d_model) # batch=2, seq_len=128
out = swiglu_ffn(x) # 输出形状: (2, 128, 4096)
print(f"输入形状: {x.shape}, 输出形状: {out.shape}")
# 参数量: 4096*11008 * 2 (gate + up) + 11008*4096 (down) = 135,266,304 (约 135M)
print(f"参数量: {sum(p.numel() for p in swiglu_ffn.parameters()):,}")

5. 参数量计算与对比

FFN 的参数量在整个 Transformer Block 中占据主导地位。准确计算参数量是显存规划和并行策略设计的基础。

5.1 标准 FFN vs SwiGLU FFN 参数量对比

以 $d_{model} = 4096$ 为例,分别计算两种 FFN 结构的参数量。

标准 FFN($d_{ff} = 4 \times d_{model} = 16384$)

参数矩阵 形状 参数量
$W_{up}$ (4096, 16384) 67,108,864
$W_{down}$ (16384, 4096) 67,108,864
合计 134,217,728 (134M)

SwiGLU FFN($d_{ff} = \frac{8}{3} \times d_{model} \approx 10922$,实际取 11008)

参数矩阵 形状 参数量
$W_{gate}$ (4096, 11008) 45,088,768
$W_{up}$ (4096, 11008) 45,088,768
$W_{down}$ (11008, 4096) 45,088,768
合计 135,266,304 (135M)

5.2 对比分析

对比维度 标准 FFN SwiGLU FFN
权重矩阵数量 2 3
中间维度 $4 \times d_{model}$ $\frac{8}{3} \times d_{model}$
总参数量 $8 \times d_{model}^2$ $3 \times d_{model} \times \frac{8}{3} d_{model} = 8 \times d_{model}^2$
以 $d_{model}=4096$ 计算 134M 135M
GEMM 次数(前向) 2 3
激活函数 全局施加 门控选择性施加

关键发现:两种结构的总参数量几乎相同($8 \times d_{model}^2$),但 SwiGLU 通过门控机制实现了更好的效果。代价是前向传播需要 3 次 GEMM 而非 2 次,计算量略有增加。

5.3 与 Attention 模块的参数量对比

模块 参数量 占 Block 比例
Self-Attention ($W_Q, W_K, W_V, W_O$) $4 \times d_{model}^2 = 67M$ ~33%
FFN(标准或 SwiGLU) $8 \times d_{model}^2 = 135M$ ~67%
LayerNorm $\times 2$ $4 \times d_{model} = 16K$ <0.01%
单 Block 合计 ~201M

FFN 的参数量约为 Attention 的 2 倍,占单个 Block 参数总量的约 2/3。这意味着任何针对 FFN 的优化(并行切分、量化、专家化)都将对整体效率产生显著影响。


6. FFN 中间维度的选择:为什么是 4x 和 8/3x

6.1 标准 FFN 为什么选择 4 倍扩展

$d_{ff} = 4 \times d_{model}$ 的选择来自原始 Transformer 论文”Attention Is All You Need”。Vaswani 等人设定 $d_{model}=512, d_{ff}=2048$,形成了 4 倍扩展比。这个比例后来成为了事实上的标准。

为什么是 4 倍而不是 2 倍或 8 倍?这在很大程度上是经验性的选择。从直觉上说:

  • 太小(如 2x):高维空间不够大,非线性变换的表达能力受限,模型效果下降
  • 太大(如 8x):参数量急剧膨胀,但效果提升边际递减,参数效率降低
  • 4x 是一个平衡点:在模型效果和参数效率之间取得较好的折中

后续研究(如 Kaplan 等人关于 Scaling Laws 的工作)也表明,在固定总参数量的前提下,$d_{ff} / d_{model}$ 的比值在 4 附近是近似最优的。这个比值过大或过小都会导致模型在同等参数量下表现变差。

6.2 SwiGLU 为什么用 8/3 倍扩展

SwiGLU 将标准 FFN 的 2 个矩阵变成了 3 个矩阵。如果继续使用 $d_{ff} = 4 \times d_{model}$,总参数量会变成:

$$
3 \times d_{model} \times 4d_{model} = 12 \times d_{model}^2
$$

这比标准 FFN 的 $8 \times d_{model}^2$ 多了 50% 的参数。为了在引入门控机制的同时保持总参数量不变,需要调整中间维度:

$$
3 \times d_{model} \times d_{ff} = 8 \times d_{model}^2
$$

$$
d_{ff} = \frac{8}{3} \times d_{model} \approx 2.667 \times d_{model}
$$

以 $d_{model} = 4096$ 计算:$d_{ff} = \frac{8}{3} \times 4096 = 10922.67$。

在工程实现中,$d_{ff}$ 通常会被取整到特定数值的倍数(如 128 或 256 的倍数),以确保矩阵维度能够被 GPU 的 Tensor Core 高效处理。Tensor Core 在执行矩阵乘法时要求维度满足特定对齐要求(例如 FP16 下通常是 8 的倍数,实践中对齐到 128 或 256 可以获得更好的利用率)。

以 LLaMA-2-7B 为例:$10922.67$ 取整到 $11008$($= 43 \times 256$),这就是我们看到的 LLaMA 配置中 $d_{ff} = 11008$ 的由来。


7. 张量并行切分 FFN

当模型规模大到单张 GPU 放不下时,需要将模型参数切分到多张 GPU 上——这就是张量并行(Tensor Parallelism, TP)。FFN 占据 Block 参数的 2/3,如何高效切分 FFN 直接决定了张量并行的整体效率。

7.1 Megatron-LM 的列切分 + 行切分策略

Megatron-LM 提出了一种精巧的切分方案,核心思想是:**$W_{up}$(和 $W_{gate}$)按列切分,$W_{down}$ 按行切分**,使得中间结果不需要通信,只在最终输出时做一次 AllReduce。

以 2 张 GPU 为例,标准 FFN 的切分如下:

第一步:$W_{up}$ 列切分

将 $W_{up}$ 的列(输出维度)均匀分配给两张 GPU:

1
2
GPU 0: W_up[:, :d_ff/2]    形状 (d_model, d_ff/2)
GPU 1: W_up[:, d_ff/2:] 形状 (d_model, d_ff/2)

每张 GPU 各自计算自己那部分的升维结果:

1
2
GPU 0: h_0 = activation(x @ W_up_0)    形状 (N, d_ff/2)
GPU 1: h_1 = activation(x @ W_up_1) 形状 (N, d_ff/2)

注意:这里 $x$ 需要在两张 GPU 上各有一份完整拷贝(通过前一步的 AllReduce 或广播获得)。列切分的关键优势在于:激活函数可以在切分后独立施加——因为激活函数是逐元素操作,不依赖其他维度的值,所以对完整向量施加激活等价于对各部分分别施加激活。

第二步:$W_{down}$ 行切分

将 $W_{down}$ 的行(输入维度)均匀分配:

1
2
GPU 0: W_down[:d_ff/2, :]    形状 (d_ff/2, d_model)
GPU 1: W_down[d_ff/2:, :] 形状 (d_ff/2, d_model)

每张 GPU 用自己的中间结果乘以自己的 $W_{down}$ 部分:

1
2
GPU 0: out_0 = h_0 @ W_down_0    形状 (N, d_model)
GPU 1: out_1 = h_1 @ W_down_1 形状 (N, d_model)

第三步:AllReduce 求和

最终输出等于两张 GPU 的部分结果之和:

1
output = AllReduce(out_0, out_1) = out_0 + out_1    形状 (N, d_model)

这是因为:

$$
W_{down} \cdot h = \begin{bmatrix} W_{down,0} \ W_{down,1} \end{bmatrix}^T \cdot \begin{bmatrix} h_0 \ h_1 \end{bmatrix} = W_{down,0}^T \cdot h_0 + W_{down,1}^T \cdot h_1
$$

7.2 SwiGLU 的切分

对于 SwiGLU FFN,$W_{gate}$ 和 $W_{up}$ 都按列切分,$W_{down}$ 按行切分。逻辑完全一致:

1
2
3
4
5
6
7
8
9
10
11
GPU 0: gate_0 = Swish(x @ W_gate_0)    # (N, d_ff/2)
up_0 = x @ W_up_0 # (N, d_ff/2)
mid_0 = gate_0 * up_0 # 逐元素相乘,(N, d_ff/2)
out_0 = mid_0 @ W_down_0 # (N, d_model)

GPU 1: gate_1 = Swish(x @ W_gate_1) # (N, d_ff/2)
up_1 = x @ W_up_1 # (N, d_ff/2)
mid_1 = gate_1 * up_1 # (N, d_ff/2)
out_1 = mid_1 @ W_down_1 # (N, d_model)

output = AllReduce(out_0, out_1) # (N, d_model)

整个 FFN 在前向传播中只需要一次 AllReduce。而每张 GPU 上的三次 GEMM(gate/up/down)和逐元素操作都是本地计算,不需要通信。这使得张量并行在 FFN 上的通信开销非常低——每个 Block 的 FFN 只贡献一次 AllReduce(Attention 也贡献一次,所以每个 Block 共两次 AllReduce)。

7.3 工程考量

在实际部署中,$d_{ff}$ 需要能被 TP 的 GPU 数量整除。例如,$d_{ff} = 11008$ 在 TP=2 时切分为每卡 5504,TP=4 时每卡 2752,TP=8 时每卡 1376——都能整除。这也是 $d_{ff}$ 选取时会优先考虑 2 的幂次或者 128/256 倍数的原因之一。


8. MoE:将 FFN 拆分为多个专家

8.1 核心思想

混合专家模型(Mixture of Experts, MoE)是一种扩大模型容量而不等比例增加计算量的技术,其核心改造对象正是 FFN 层。

基本思路非常直观:与其使用一个巨大的 FFN,不如将其拆分为多个较小的”专家”(Expert),每个专家都是一个独立的 FFN。对于每个输入 token,只激活其中少数几个专家进行计算,其余专家保持”休眠”。

例如,一个 MoE 层可能包含 64 个专家 FFN,但每个 token 只选择其中 2 个专家(Top-2)进行前向计算。这意味着:

  • 模型总参数量 = 64 个专家的参数量之和(很大)
  • 每个 token 的计算量 = 2 个专家的计算量(很小)

这就实现了”参数量大但计算量小”的目标——模型有足够的容量存储知识,但推理和训练时的计算成本可控。

8.2 路由机制(Router)

MoE 的关键组件是路由器(Router / Gate),它决定每个 token 应该被发送到哪些专家。

最常见的路由方式是 Top-K 路由:

$$
g(x) = \text{softmax}(W_r \cdot x)
$$

$$
\text{TopK}(g(x)) \to \text{选出概率最大的 K 个专家}
$$

其中 $W_r$ 是路由器的权重矩阵,形状为 $(d_{model}, n_{experts})$。路由器将每个 token 的表示向量映射到一个 $n_{experts}$ 维的概率分布,概率最大的 K 个专家被选中。

最终输出是被选中专家输出的加权和:

$$
\text{MoE}(x) = \sum_{i \in \text{TopK}} g_i(x) \cdot \text{Expert}_i(x)
$$

权重 $g_i(x)$ 经过 re-normalize(重归一化为和为 1),确保输出的尺度与单个 FFN 一致。

8.3 负载均衡

MoE 训练中的一个棘手问题是负载不均衡——路由器可能倾向于把大部分 token 都发送给少数几个”明星专家”,导致其他专家得不到训练、逐渐退化,进而加剧不均衡,形成恶性循环。

常用的缓解策略包括:

辅助损失(Auxiliary Loss / Load Balancing Loss)

在训练目标中加入一项额外的损失函数,惩罚负载不均匀的情况。常见形式是:

$$
L_{aux} = \alpha \cdot n_{experts} \cdot \sum_{i=1}^{n_{experts}} f_i \cdot p_i
$$

其中 $f_i$ 是专家 $i$ 实际接收的 token 比例,$p_i$ 是路由器分配给专家 $i$ 的平均概率,$\alpha$ 是平衡系数。当所有专家的负载均匀时,该损失取最小值。

容量因子(Capacity Factor)

为每个专家设定一个最大接收 token 数。超出容量的 token 会被丢弃(送入残差路径)或重新路由到其他专家。这从硬性约束的角度防止单个专家过载。

Expert Choice 路由

与传统的 token-choose-expert(由 token 选择专家)不同,Expert Choice 让每个专家主动选择要处理的 top-k 个 token。这天然保证每个专家处理相同数量的 token,从根本上解决负载不均衡的问题。

8.4 AI Infra 关联:Expert Parallelism

MoE 引入了一种新的并行维度——专家并行(Expert Parallelism, EP)。由于每个专家是独立的 FFN,可以将不同的专家放置在不同的 GPU 上。

这带来了独特的通信模式:传统的张量并行使用 AllReduce(所有卡都参与),而 MoE 的专家并行使用 All-to-All 通信——每个 GPU 需要把分配给其他 GPU 上专家的 token 发送过去,同时接收分配给本地专家的 token。这种通信模式对网络拓扑和通信库的要求与 AllReduce 截然不同,是 MoE 系统设计中的核心挑战。

Mixtral 8x7B、DeepSeek-V2 等模型的成功部署,背后都需要精心设计的 EP 调度策略和通信优化。


9. CUDA Kernel 融合在 FFN 中的应用

9.1 为什么需要 Kernel 融合

GPU 执行计算的基本单位是 kernel——每次从 CPU 调度一个 kernel 到 GPU 上执行。每次 kernel 启动(launch)都有固定的调度开销(通常几微秒),而且每个独立的 kernel 都需要从 HBM(高带宽显存)读取输入、写回输出。

FFN 的前向传播包含多个步骤:GEMM、激活函数、逐元素乘法等。如果每个步骤都作为独立的 kernel 执行,中间结果就需要反复在 HBM 和寄存器/SRAM 之间搬运。对于 SwiGLU FFN,一次前向传播的逐元素操作包括:

  1. Swish 激活:$\text{gate} = x_gate \cdot \sigma(x_gate)$(涉及 sigmoid 和逐元素乘法)
  2. 门控乘法:$\text{hidden} = \text{gate} \otimes \text{up}$

如果分开执行,gate 向量需要写入 HBM 后再读回来做乘法。融合后,这些操作可以在一个 kernel 内完成——gate 的中间结果驻留在寄存器中,直接与 up 相乘,完全不经过 HBM。

9.2 FFN 中的典型融合场景

场景一:Swish 激活 + 门控乘法融合

Swish(gate) * up 融合为一个 kernel。两个 $(N, d_{ff})$ 的输入产生一个 $(N, d_{ff})$ 的输出,HBM 读写量从 4 次降为 3 次(读 2 次 + 写 1 次,省去了 Swish 中间结果的写入和读回)。

场景二:GEMM + 激活融合(Epilogue Fusion)

在 GEMM kernel 的”收尾阶段”(epilogue)中嵌入激活函数。以 $W_{gate}$ 的 GEMM 为例,标准流程是:

  1. kernel A: $x_gate = x \cdot W_{gate}$,写入 HBM
  2. kernel B: $\text{gate} = \text{Swish}(x_gate)$,从 HBM 读出 $x_gate$ 并写回 gate

融合后,在 GEMM kernel 的最后一步(将结果写回 HBM 之前),直接在寄存器中对每个输出元素施加 Swish,然后写入 HBM。这样只需一个 kernel 和一次 HBM 写入。

cuBLAS 和 CUTLASS 等 GEMM 库都提供了 epilogue fusion 的接口,支持在 GEMM 输出后拼接自定义的逐元素操作。

场景三:$W_{gate}$ 和 $W_{up}$ 的 GEMM 合并

$W_{gate}$ 和 $W_{up}$ 的输入都是同一个 $x$,可以将两个矩阵沿列方向拼接为一个 $(d_{model}, 2 \times d_{ff})$ 的大矩阵,执行一次大 GEMM,然后将输出切分为 gate 和 up 两部分。这将两次 GEMM 合并为一次,减少 kernel launch 开销,同时大 GEMM 的 GPU 利用率通常优于两个小 GEMM。

1
2
3
4
5
6
7
8
# 分开执行(2 次 GEMM launch):
gate = x @ W_gate # (N, d_model) x (d_model, d_ff) = (N, d_ff)
up = x @ W_up # (N, d_model) x (d_model, d_ff) = (N, d_ff)

# 合并执行(1 次 GEMM launch):
W_fused = concat(W_gate, W_up, dim=1) # (d_model, 2*d_ff)
fused_out = x @ W_fused # (N, 2*d_ff)
gate, up = split(fused_out, d_ff, dim=1)

场景四:偏置加法 + 残差加法 + LayerNorm 融合

FFN 的输出通常要经过残差加法和后续的 LayerNorm。将这些逐元素操作融合为一个 kernel,可以将 FFN 输出、残差输入和 LayerNorm 参数一次性读入,在片上完成加法和归一化后写回,避免多次 HBM 往返。

9.3 性能影响

在实际大模型推理中,FFN 的 kernel 融合带来的加速效果因场景而异:

  • Decode 阶段(每次只处理 1 个 token):GEMM 退化为矩阵-向量乘法,计算量小,kernel launch 和 HBM 带宽是主要瓶颈,融合收益显著(可达 20-40% 的加速)
  • Prefill 阶段(处理整段 prompt):GEMM 的计算量占主导,逐元素操作的 HBM 开销占比较小,融合收益相对有限(通常 5-15%)

🎯 自我检验清单

完成本文学习后,检验自己是否真正理解了 FFN 模块:

  • 能说清 FFN 在 Transformer Block 中的角色——为什么 Attention 之后还需要 FFN,FFN 负责什么
  • 能在白板上画出标准 FFN 和 SwiGLU FFN 的结构图,标注每个矩阵的维度和数据流方向
  • 能手算标准 FFN 和 SwiGLU FFN 的参数量,并解释为什么两者参数量几乎相同
  • 能解释为什么 FFN 的中间维度选择 4x(标准)和 8/3x(SwiGLU),以及 LLaMA 中 11008 这个数值的由来
  • 能分别说清 ReLU、GELU、Swish 三种激活函数的数学定义、直觉含义和各自优缺点
  • 能解释 ReLU 的神经元死亡问题为什么会发生,以及 GELU 和 Swish 如何避免这个问题
  • 能解释 SwiGLU 的门控机制——$W_{gate}$ 和 $W_{up}$ 各自的作用,为什么门控比统一激活更有效
  • 能描述 Megatron-LM 对 FFN 做张量并行的切分方式——列切分 $W_{up}$/$W_{gate}$,行切分 $W_{down}$,为什么这样切分只需要一次 AllReduce
  • 能概述 MoE 的基本思想——多个专家 FFN、路由机制、负载均衡,以及 Expert Parallelism 的通信模式
  • 能举出至少两个 FFN 中 CUDA kernel 融合的具体场景

📚 参考资料

论文

教程与博客