趋近智
训练大型语言模型带来了内存方面的困难,仅靠硬件的线性扩展无法单独解决。使用标准的分布式数据并行 (DDP) 时,训练过程会达到一个由单个 GPU 显存决定的严格上限。明确此内存占用的具体构成是必需的,以便构建能训练数十亿或数万亿参数模型的系统。
为优化内存使用,我们必须先量化训练单个参数的开销。一个常见误解是,拥有 Ψ 个参数的模型需要 4Ψ 字节(假定为 32 位浮点数)或 2Ψ 字节(假定为 16 位浮点数)的内存。实际上,使用 Adam 优化器和混合精度进行训练时的内存占用要大得多。
在标准混合精度训练流程中(计算使用 FP16 或 BF16,权重更新使用 FP32),系统必须维护模型状态的多个副本。对于模型中的每个参数,内存分配包含:
将这些部分加起来得到混合精度训练的内存常数:
Mparam=2+2+4+4+4=16 字节
因此,拥有 Ψ 个参数的模型需要 16Ψ 字节的静态内存。一个 70 亿参数的模型,在现代大型语言模型(LLM)环境下常被认为是“小型”的,仅加载权重和优化器状态就需要 7×109×16 字节,或大约 112 GB 的显存。这在处理任何一个 token 之前就已超出 NVIDIA A100 (80GB) 的容量。
DDP 的运作方式是将整个模型状态复制到集群中的每个 worker 上。如果部署一个包含 N 个 GPU 的集群,DDP 会创建 N 份相同的模型参数、梯度和优化器状态副本。DDP 中的通信步骤(AllReduce)会同步 worker 间的梯度,但它不会减少任何单个设备上的内存占用。
随着模型规模的增长,DDP 的效率会降低。虽然 DDP 允许通过增加 GPU 来扩展批处理大小,但它不允许扩展模型规模。每个 GPU 的内存需求保持不变,无论集群大小如何:
内存DDP=16Ψ+激活内存+碎片
这种架构导致大量的内存冗余。在一个有 16 个 GPU 训练 10 亿参数模型(占用 16 GB 内存)的集群中,总集群内存使用量为 16×16 GB=256 GB。然而,存储的独特信息只有 16 GB。剩下的 240 GB 是重复数据。
混合精度训练中每个参数的内存分配明细。
零冗余优化器 (ZeRO) 解决了这种低效问题,其原理是:虽然所有 GPU 在正向和反向传播期间都需要访问所有权重,但它们无需同时持久化所有权重、梯度和优化器状态。
ZeRO 将模型状态在可用的数据并行进程之间进行分区(分片)。如果有 Nd 个 GPU,ZeRO 会分割数据,使得每个 GPU 拥有总状态的 1/Nd。这种分片可分三个渐进阶段实施,每个阶段都能带来更大的内存节省,代价是通信复杂程度增加。
优化器状态(主权重、动量、方差)构成内存占用的主体(16 字节中的 12 字节)。在阶段 1 中,这些状态在 Nd 个 GPU 之间分片。每个 GPU 只更新其分配到的优化器状态分区。
内存Stage1=2Ψ (权重)+2Ψ (梯度)+Nd12Ψ (优化器状态)
阶段 2 将分片扩展到梯度。梯度在反向传播期间计算出来后,它们会立即被归约和分片,而不是在本地聚合。
内存Stage2=2Ψ (权重)+Nd2Ψ+12Ψ
阶段 3 是 FSDP 的核心。它分片模型参数本身。在此阶段,GPU 只保留模型的一小部分。当计算需要某个特定层时,参数会从其他 GPU 收集过来,使用后立即丢弃以释放内存。
内存Stage3=Nd16Ψ
ZeRO 阶段 3 的理论极限使每个设备的内存占用趋近于零,随着设备数量 Nd 的增加,从而将大部分显存留给激活和更大的批处理大小。
内存效率的差异变得明显,随着我们扩展 GPU 数量。使用 DDP 时,增加 GPU 不会降低每个设备的内存压力。使用 FSDP (ZeRO 阶段 3) 时,内存压力会随着硬件的增加而线性减少。
例如,一个训练拥有 Ψ 参数的大型语言模型的情景。下面的图表显示了随着集群规模的扩大,80GB A100 GPU 上可训练的最大模型规模。
随着集群规模扩展,每个 GPU 可训练的最大模型规模(以十亿参数计)。
在 DDP 配置中(红线),每个 GPU 的最大模型规模被严格限定在大约 35 亿参数(为激活留出缓冲)。增加 60 个 GPU 也不会改变这个限制。在 FSDP 配置中(蓝线),容量线性扩展。拥有 64 个 GPU 时,集群能够有效训练接近 1750 亿参数的模型,因为 16Ψ 的静态状态被稀疏地分布在集群中。
需要注意,ZeRO 只减少 模型状态 的内存占用。它本身不会减少 激活 所需的内存,即那些为反向传播存储的层中间输出。激活内存取决于批处理大小、序列长度和 Transformer 架构(例如,隐藏维度、注意力头)。
尽管 FSDP 从模型参数中释放了大量显存,但训练 TB 级模型通常需要将 FSDP 与 激活检查点 结合(在反向传播期间重新计算激活),以将激活内存保持在限制内。我们将在第 3 章实施这种集成。
通过从 DDP 转向 FSDP,我们从模型架构受限于单设备限制的模式,转移到模型规模仅受总集群容量和网络带宽限制的模式。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造