为什么Transformer中优先使用 Layer Normalization (LN) 而不是 Batch Normalization (BN)?
在 Transformer 及其他处理序列数据的模型(如 RNN)中,几乎总是优先选择 Layer Normalization (LN) 而不是 Batch Normalization (BN)。
这主要是由文本/序列数据的天然特性以及模型训练的实际限制决定的。我们可以从输入张量(Tensor)的维度切入,深入理解其背后的原因。
在 Transformer 中,输入张量的维度通常是 [N, L, D]:
- N (Batch Size):批次大小(包含多少个句子)
- L (Sequence Length):序列长度(每个句子包含多少个词/Token)
- D (Hidden Dimension):特征/隐藏层维度(每个词的词向量维度)
以下是为什么选择 LN 而不用 BN 的核心原因:
1. 应对变长序列(Variable Sequence Length)
这是最关键的原因。在自然语言处理中,一个 Batch 内的句子长度往往参差不齐,通常需要用 Padding(填充词,如 <PAD>)将它们补齐到相同长度。
- 如果用 BN:BN 是在 Batch 维度(N)上进行归一化。
- 如果对每个位置计算均值和方差,对于较长的句子,其尾部位置(比如第 500 个词)可能在整个 Batch 中只有一两个真实词,其余全是 Padding。这会导致均值和方差的计算极其不稳定,甚至因为样本数不足导致除以 0。
- 如果把 N 和 L 展平一起做 BN,大量的 Padding token 参与统计,会严重污染真实数据的分布特征。
- 用 LN:LN 是在特征维度(D)上进行归一化(即对每一个词的向量单独做归一化)。它只关注当前这一个词的词向量,完全不受句子长度和其他句子的影响,因此完美避开了变长序列和 Padding 带来的统计污染问题。
2. 对 Batch Size 极其敏感 (Batch Size Dependency)
- BN 的痛点:BN 依赖于较大的 Batch Size 才能统计出具备全局代表性的均值和方差。如果 Batch Size 太小(如 2 或 4),BN 的统计量波动极大,导致模型难以收敛。随着 Transformer 模型(尤其是大语言模型 LLM)越来越大,受限于显存,实际训练时的单卡 Batch Size 往往非常小(甚至为 1)。在这种情况下,BN 完全失效。
- LN 的优势:LN 的计算发生在单一的一个样本、单一的一个 Token 内部,完全不依赖于 Batch Size。无论 Batch Size 是 1 还是 1024,LN 的表现都高度一致且稳定。
3. 训练与推理(推理阶段)的分布差异
- BN 的痛点:BN 在训练时使用当前 Batch 的统计量,而在推理(测试)时,使用的是训练期间累积的全局滑动平均(Moving Average)均值和方差。在序列生成任务中(如 GPT 的自回归生成),推理是逐字生成的,动态长度的变化使得训练时累积的统计量很难完美匹配推理时的实际分布,从而引发误差。
- LN 的优势:LN 在训练和推理时的计算逻辑完全一致。每次输入一个词,它就立刻根据该词自身的 D 个维度计算均值和方差进行归一化,不需要保存和使用历史滑动平均值。
4. 物理语义的差异 (CV 与 NLP 数据的不同)
- 在计算机视觉 (CV) 中:图像数据的维度是
[N, C, H, W](通道数 C 对应特征维度)。同一个通道(如提取边缘的通道)在不同图像(N)中具有相似的物理意义。因此,在 N 维度上把所有的边缘特征拿出来一起做 BN 是非常合理的。 - 在自然语言处理 (NLP) 中:一个词的语义是由其词向量(D 维)作为一个整体来表达的。一个词向量内部的各个维度之间有着紧密的关联。LN 对这 D 个维度进行归一化,相当于对这个词的特征表达进行了一次整体的缩放和平移,这在语义空间中更有意义,能更好地稳定单个词的特征表示。
总结对比
| 特性 | Batch Normalization (BN) | Layer Normalization (LN) |
|---|---|---|
| 归一化方向 | 跨 Batch (N) 和 序列 (L),针对每个 特征 (D) | 跨 特征 (D),针对每个 样本 (N) 的每个词 (L) |
| 对变长序列的支持 | 差(受 Padding 污染,末端词统计不稳定) | 完美(每个词独立计算) |
| 对 Batch Size 的依赖 | 强(小 Batch 下失效) | 无(Batch Size 为 1 也能工作) |
| 训练与推理一致性 | 差(依赖滑动平均,自回归时易出问题) | 完美(行为完全一致) |
| 适用领域 | 计算机视觉 (CNN) | 序列模型 (Transformer, RNN) |
因此,在提出 Transformer 架构及其后续的所有大模型(BERT, GPT 等)中,Layer Normalization 及其变体(如 RMSNorm)成为了绝对的标准配置。