ZeRO (Zero Redundancy Optimizer) 是 DeepSpeed(以及后来的 PyTorch FSDP)的核心技术,旨在解决大模型训练中的显存瓶颈。它通过在数据并行(Data Parallelism)的各个 GPU 之间去除冗余数据来节省显存。 在大模型训练中,显存主要被三类数据占用:优化器状态 (Optimizer States)、梯度 (Gradients) 和 模型参数 (Parameters)。 ZeRO 的三个阶段分别切分(Shard/Partition)了以下内容: 简短总结 Stage 1: 仅切分 优化器状态。 Stage 2: 切分 优化器状态 + 梯度。 Stage 3: 切分 优化器状态 + 梯度 + 模型参数。 --- 详细解析 为了理解这三个阶段,假设我们使用 Adam 优化器 和 混合精度训练 (Mixed Precision),模型参数量为 $\Phi$。 1. ZeRO Stage 1:切分优化器状态 (Partitioning Optimizer States) 切分内容:仅将 Optimizer States(如 Ada...