有效批大小仍是决定大语言模型收敛表现的一个主要超参数。虽然激活检查点和混合精度可以支持更大的模型结构,但它们本身不能解决每个批次容纳足够令牌数的限制,以保证优化步骤的稳定性。梯度累积将微批大小(受限于GPU显存)与全局有效批大小(由收敛要求决定)解耦开来。对于完全分片数据并行(FSDP),梯度累积的运作方式与标准分布式数据并行(DDP)不同。在DDP中,累积常用于通过跳过若干次迭代的AllReduce同步步骤来减少通信开销。在FSDP中,特别是在使用ZeRO-3分片策略时,累积、通信和内存分布方式的关系需要一种不同的方法,以避免意外的内存溢出(OOM)错误。分片累积的运作原理标准梯度累积需要在执行优化器步骤前进行$N$次正向和反向传播。从数学上看,如果$B_{micro}$是每个GPU的微批大小,$G$是GPU数量,$N_{acc}$是累积步数,那么有效批大小$B_{eff}$为:$$B_{eff} = B_{micro} \times G \times N_{acc}$$在非分片设置(DDP)中,梯度被累积在每个设备上密集且完整模型大小的张量中。同步操作(AllReduce)每$N_{acc}$步才进行一次。在FSDP中,梯度是分片的。每个进程只负责总梯度参数的一小部分($1/G$)。问题在于梯度的生命周期。在反向传播过程中,FSDP计算层的梯度。计算完成后,一个ReduceScatter操作会将这些梯度在不同进程间聚合并进行分片。然后,完整大小的梯度会被丢弃以释放内存。如果我们盲目应用标准累积逻辑,我们将面临通信效率和内存使用之间的权衡。no_sync的注意事项PyTorch提供了一个model.no_sync()上下文管理器,这是DDP中梯度累积的标准工具。它会阻止通信钩子在反向传播过程中触发。然而,在FSDP中使用no_sync()会从根本上改变内存占用情况。如果通信被禁用,FSDP无法执行ReduceScatter操作。因此,每个进程必须保留整个模型未分片的梯度,直到同步发生。对于使用Float16的700亿参数模型,仅梯度就需要140GB内存。如果no_sync()被启用,每个GPU都尝试分配这140GB,很可能立即导致内存溢出(OOM)。所以,对于需要FSDP的大模型,我们通常不使用no_sync()。相反,我们允许ReduceScatter在每个微步骤中发生。这意味着在大模型FSDP中:内存效率高: 我们接受每个微批次ReduceScatter带来的通信开销。累积位置: 累积发生在分片的梯度上,这些梯度是模型大小的$1/G$。分片累积的内部逻辑当不使用no_sync()时,训练循环按以下步骤进行:正向传播: 参数被收集(AllGather),计算发生,参数被释放。反向传播: 参数再次被收集,梯度被计算。Reduce-Scatter: 梯度立即被同步和分片。累积: 生成的分片梯度被添加到本地参数分片的.grad属性中。PyTorch会自动处理向.grad的添加。如果.grad没有被设置为None(通过zero_grad()),自动梯度引擎会将新计算的梯度添加到现有值中。下图说明了内存占用情况的比较,比较了DDP风格累积(持有完整梯度)和FSDP风格累积(累积分片梯度)。digraph G { rankdir=TB; node [shape=box, style=filled, fontname="Helvetica"]; subgraph cluster_ddp { label="DDP / FSDP (使用 no_sync())"; style=dashed; color="#adb5bd"; step1 [label="微批次1反向传播", fillcolor="#eebefa"]; mem1 [label="分配:完整未分片梯度", fillcolor="#ffc9c9"]; step2 [label="微批次2反向传播", fillcolor="#eebefa"]; mem2 [label="分配:完整未分片梯度\n(已累积)", fillcolor="#ffc9c9"]; comm [label="通信 (AllReduce/ReduceScatter)", fillcolor="#91a7ff"]; step1 -> mem1 -> step2 -> mem2 -> comm; } subgraph cluster_fsdp { label="FSDP标准累积"; style=dashed; color="#adb5bd"; f_step1 [label="微批次1反向传播", fillcolor="#b2f2bb"]; f_comm1 [label="ReduceScatter", fillcolor="#91a7ff"]; f_mem1 [label="存储:分片梯度 (1/G)", fillcolor="#d8f5a2"]; f_step2 [label="微批次2反向传播", fillcolor="#b2f2bb"]; f_comm2 [label="ReduceScatter", fillcolor="#91a7ff"]; f_mem2 [label="累积到分片梯度中", fillcolor="#d8f5a2"]; f_step1 -> f_comm1 -> f_mem1 -> f_step2 -> f_comm2 -> f_mem2; } }累积过程中内存状态的比较。上方路径显示了延迟通信时的内存峰值风险。下方路径显示了内存安全的FSDP方法,其中通信在每个微批次发生,以保持分片状态。实现方式要在FSDP中实现梯度累积,我们必须手动控制优化器步骤和梯度清零。我们不需要专门的上下文管理器;我们只需依靠PyTorch自动梯度引擎的机制将梯度累积到叶张量(即分片参数)中。以下实现展示了一个FSDP训练循环。请注意损失的归一化处理。因为优化器步骤每$N$个微批次才发生一次,损失梯度必须按$1/N_{acc}$进行缩放,以避免有效学习率随累积步数而缩放。from torch.distributed.fsdp import FullyShardedDataParallel as FSDP # 超参数 accumulation_steps = 4 model = FSDP(base_model, ...) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) # 训练循环 model.train() optimizer.zero_grad(set_to_none=True) for step, (inputs, labels) in enumerate(dataloader): # 1. 正向传播 outputs = model(inputs) loss = loss_fn(outputs, labels) # 2. 为累积缩放损失 # 这可以确保梯度的量级保持一致 # 无论累积步数如何。 loss = loss / accumulation_steps # 3. 反向传播 # FSDP在此处自动处理ReduceScatter。 # 梯度被累积到model.parameters().grad中 # 它们已经被分片。 loss.backward() # 4. 条件步骤 if (step + 1) % accumulation_steps == 0: # 可选:梯度裁剪 # FSDP正确处理分片梯度的裁剪 model.clip_grad_norm_(1.0) # 更新权重 optimizer.step() # 清除梯度以进行下一个累积周期 optimizer.zero_grad(set_to_none=True)吞吐量影响这种方法解决了内存限制,但也引入了FSDP特有的性能表现。在DDP中,梯度累积通过减少昂贵网络调用的频率来提高吞吐量。在FSDP中(不使用no_sync),我们在每个微批次上执行ReduceScatter。所以,我们看不到通信开销的同等减少。FSDP中梯度累积的主要益处是内存可行性,而非隐藏通信。它允许您使用目标全局批大小(例如,400万个令牌)进行训练,即使硬件由于激活和临时缓冲区的内存占用,只能支持1或2个序列的本地微批次。然而,仍然有一些效率提升。通过减少优化器更新的频率,我们减少了与优化器步骤本身相关的内核启动开销(例如,AdamW逻辑),并降低了更新优化器状态分片的频率。对于非常大的模型,其中优化器步骤的开销不容忽视,这会带来可衡量的吞吐量提高。混合分片(HSDP)下的梯度累积当使用混合分片数据并行(HSDP)时,其中参数在节点内分片但在节点间复制,我们重新获得一些灵活性。HSDP在节点内执行ReduceScatter,在节点间执行AllReduce。在这种配置下,人们可能会考虑在节点间AllReduce之前在本地累积梯度。然而,PyTorch当前的FSDP实现紧密关联了节点内和节点间的通信。标准建议仍然是允许通信钩子在每个微批次触发,以确保内存持续处于分片状态。优化累积阶段通常需要将当前层的ReduceScatter与前一层的计算(反向预取)重叠进行,这通过下一章讨论的backward_prefetch策略进行配置。当累积阻止我们减少通信调用总数时,这种重叠很要紧。