现代大语言模型(如LLaMA)中常使用的 RMSNorm 和传统的 LayerNorm 有什么区别?
在现代大语言模型(如 LLaMA、Mistral、Gemma 等)中,RMSNorm(Root Mean Square Normalization,均方根归一化) 已经基本上取代了传统的 LayerNorm(层归一化)。
这两者的核心区别在于:RMSNorm 是 LayerNorm 的一种“简化加速版”。它去除了 LayerNorm 中的均值计算(Mean-centering),从而在保持模型表现几乎不变的前提下,显著提升了计算效率。
以下是它们之间详细的对比和原理解析:
1. 数学公式上的区别
传统的 LayerNorm
LayerNorm 需要同时计算输入向量 的均值(Mean)和方差(Variance),并将输入进行平移和缩放。
公式如下:
- 计算均值:
- 计算方差:
- 归一化并仿射变换:
(其中, 是向量维度, 是防止分母为0的极小值, 和 是模型可学习的缩放和平移参数。)
现代的 RMSNorm
RMSNorm 认为,LayerNorm 中真正起作用的是缩放(Scaling),而不是平移(Shifting/Mean-centering)。因此,它直接假设均值 ,只计算输入的均方根(RMS)。
公式如下:
- 计算均方根:
- 归一化并缩放:
(注意:RMSNorm 中没有减去均值 的步骤,也去除了可学习的偏置参数 。)
2. 核心差异总结
| 特性 | LayerNorm | RMSNorm |
|---|---|---|
| 均值中心化 (Mean Centering) | 有(减去均值 ) | 无(直接使用原始 ) |
| 可学习参数 | 缩放因子 + 偏置项 | 仅有缩放因子 |
| 不变性 (Invariance) | 缩放不变性 + 平移不变性 | 仅有缩放不变性 |
| 计算复杂度 | 较高(需计算均值,再计算基于均值的方差) | 较低(只需计算平方和的均值) |
| 内存与参数开销 | 略多(需存储 和 ) | 更少(仅存储 ) |
3. 为什么 LLaMA 等大模型偏爱 RMSNorm?
A. 计算效率更高(最重要的原因)
在 LayerNorm 中,为了计算方差,必须先算出均值 ,这就需要在内存/显存中对数据进行两次遍历(或使用更复杂的并行规约算法)。
而在 RMSNorm 中,只需要一次遍历计算 的总和即可。根据原论文的测试,RMSNorm 的计算速度比 LayerNorm 快 10% 到 50%。对于拥有几十到上百层 Transformer、处理超长上下文的 LLM 来说,这种底层算子的加速能累积成巨大的性能提升。
B. 性能(准确率)几乎无损
RMSNorm 的提出者(Biao Zhang 和 Rico Sennrich, 2019)通过实验发现,LayerNorm 中的均值中心化操作对 Transformer 模型的训练稳定性及最终性能并没有决定性的影响。真正防止梯度爆炸/消失、稳定训练的核心在于动态缩放(即除以标准差或均方根)。既然去掉了均值也不影响准确率,自然选择更简单的方案。
C. 更适合现代硬件的内存带宽
大模型的推理(Inference)往往是内存带宽受限(Memory Bandwidth Bound)的。RMSNorm 少了一个偏置参数 ,并且减少了计算过程中的中间变量读写(如 ),这极其符合 GPU/TPU 等现代硬件对计算效率的极致追求。
总结
RMSNorm 就是“去掉减去均值步骤”的 LayerNorm。 现代大语言模型(如 LLaMA)使用 RMSNorm 是工程与数学结合的经典案例:通过放弃数学上看似严谨但在深度网络中并不必要的平移不变性,换取了极其宝贵的计算速度和内存效率,同时维持了极高的模型能力。