趋近智
有效批大小仍是决定大语言模型收敛表现的一个主要超参数。虽然激活检查点和混合精度可以支持更大的模型结构,但它们本身不能解决每个批次容纳足够令牌数的限制,以保证优化步骤的稳定性。梯度累积将微批大小(受限于GPU显存)与全局有效批大小(由收敛要求决定)解耦开来。
对于完全分片数据并行(FSDP),梯度累积的运作方式与标准分布式数据并行(DDP)不同。在DDP中,累积常用于通过跳过若干次迭代的AllReduce同步步骤来减少通信开销。在FSDP中,特别是在使用ZeRO-3分片策略时,累积、通信和内存分布方式的关系需要一种不同的方法,以避免意外的内存溢出(OOM)错误。
标准梯度累积需要在执行优化器步骤前进行N次正向和反向传播。从数学上看,如果Bmicro是每个GPU的微批大小,G是GPU数量,Nacc是累积步数,那么有效批大小Beff为:
Beff=Bmicro×G×Nacc
在非分片设置(DDP)中,梯度被累积在每个设备上密集且完整模型大小的张量中。同步操作(AllReduce)每Nacc步才进行一次。
在FSDP中,梯度是分片的。每个进程只负责总梯度参数的一小部分(1/G)。问题在于梯度的生命周期。在反向传播过程中,FSDP计算层的梯度。计算完成后,一个ReduceScatter操作会将这些梯度在不同进程间聚合并进行分片。然后,完整大小的梯度会被丢弃以释放内存。
如果我们盲目应用标准累积逻辑,我们将面临通信效率和内存使用之间的权衡。
no_sync的注意事项PyTorch提供了一个model.no_sync()上下文管理器,这是DDP中梯度累积的标准工具。它会阻止通信钩子在反向传播过程中触发。
然而,在FSDP中使用no_sync()会从根本上改变内存占用情况。如果通信被禁用,FSDP无法执行ReduceScatter操作。因此,每个进程必须保留整个模型未分片的梯度,直到同步发生。
对于使用Float16的700亿参数模型,仅梯度就需要140GB内存。如果no_sync()被启用,每个GPU都尝试分配这140GB,很可能立即导致内存溢出(OOM)。所以,对于需要FSDP的大模型,我们通常不使用no_sync()。相反,我们允许ReduceScatter在每个微步骤中发生。
这意味着在大模型FSDP中:
ReduceScatter带来的通信开销。当不使用no_sync()时,训练循环按以下步骤进行:
AllGather),计算发生,参数被释放。.grad属性中。PyTorch会自动处理向.grad的添加。如果.grad没有被设置为None(通过zero_grad()),自动梯度引擎会将新计算的梯度添加到现有值中。
下图说明了内存占用情况的比较,比较了DDP风格累积(持有完整梯度)和FSDP风格累积(累积分片梯度)。
累积过程中内存状态的比较。上方路径显示了延迟通信时的内存峰值风险。下方路径显示了内存安全的FSDP方法,其中通信在每个微批次发生,以保持分片状态。
要在FSDP中实现梯度累积,我们必须手动控制优化器步骤和梯度清零。我们不需要专门的上下文管理器;我们只需依靠PyTorch自动梯度引擎的机制将梯度累积到叶张量(即分片参数)中。
以下实现展示了一个FSDP训练循环。请注意损失的归一化处理。因为优化器步骤每N个微批次才发生一次,损失梯度必须按1/Nacc进行缩放,以避免有效学习率随累积步数而缩放。
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# 超参数
accumulation_steps = 4
model = FSDP(base_model, ...)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 训练循环
model.train()
optimizer.zero_grad(set_to_none=True)
for step, (inputs, labels) in enumerate(dataloader):
# 1. 正向传播
outputs = model(inputs)
loss = loss_fn(outputs, labels)
# 2. 为累积缩放损失
# 这可以确保梯度的量级保持一致
# 无论累积步数如何。
loss = loss / accumulation_steps
# 3. 反向传播
# FSDP在此处自动处理ReduceScatter。
# 梯度被累积到model.parameters().grad中
# 它们已经被分片。
loss.backward()
# 4. 条件步骤
if (step + 1) % accumulation_steps == 0:
# 可选:梯度裁剪
# FSDP正确处理分片梯度的裁剪
model.clip_grad_norm_(1.0)
# 更新权重
optimizer.step()
# 清除梯度以进行下一个累积周期
optimizer.zero_grad(set_to_none=True)
这种方法解决了内存限制,但也引入了FSDP特有的性能表现。在DDP中,梯度累积通过减少昂贵网络调用的频率来提高吞吐量。在FSDP中(不使用no_sync),我们在每个微批次上执行ReduceScatter。所以,我们看不到通信开销的同等减少。
FSDP中梯度累积的主要益处是内存可行性,而非隐藏通信。它允许您使用目标全局批大小(例如,400万个令牌)进行训练,即使硬件由于激活和临时缓冲区的内存占用,只能支持1或2个序列的本地微批次。
然而,仍然有一些效率提升。通过减少优化器更新的频率,我们减少了与优化器步骤本身相关的内核启动开销(例如,AdamW逻辑),并降低了更新优化器状态分片的频率。对于非常大的模型,其中优化器步骤的开销不容忽视,这会带来可衡量的吞吐量提高。
当使用混合分片数据并行(HSDP)时,其中参数在节点内分片但在节点间复制,我们重新获得一些灵活性。HSDP在节点内执行ReduceScatter,在节点间执行AllReduce。
在这种配置下,人们可能会考虑在节点间AllReduce之前在本地累积梯度。然而,PyTorch当前的FSDP实现紧密关联了节点内和节点间的通信。标准建议仍然是允许通信钩子在每个微批次触发,以确保内存持续处于分片状态。
优化累积阶段通常需要将当前层的ReduceScatter与前一层的计算(反向预取)重叠进行,这通过下一章讨论的backward_prefetch策略进行配置。当累积阻止我们减少通信调用总数时,这种重叠很要紧。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造