Effective batch size remains a primary hyperparameter governing the convergence behavior of Large Language Models. While activation checkpointing and mixed precision allow for larger architectures, they do not by themselves solve the problem of fitting enough tokens per batch to ensure stable optimization steps. Gradient accumulation decouples the micro-batch size (limited by GPU VRAM) from the global effective batch size (dictated by convergence requirements).

In the context of Fully Sharded Data Parallel (FSDP), gradient accumulation functions differently than in standard Distributed Data Parallel (DDP). In DDP, accumulation is often used to reduce communication overhead by skipping the AllReduce synchronization step for several iterations. In FSDP, particularly when operating with ZeRO-3 sharding strategies, the interaction between accumulation, communication, and memory layout requires a distinct approach to avoid accidental out-of-memory (OOM) errors.

## The Mechanics of Sharded Accumulation

Standard gradient accumulation executes the forward and backward passes $N_{acc}$ times before performing an optimizer step. Mathematically, if $B_{micro}$ is the micro-batch size per GPU, $G$ is the number of GPUs, and $N_{acc}$ is the number of accumulation steps, the effective batch size $B_{eff}$ is:

$$B_{eff} = B_{micro} \times G \times N_{acc}$$

In a non-sharded setup (DDP), gradients are accumulated in dense, full-model-sized tensors on each device. The synchronization (AllReduce) occurs only once every $N_{acc}$ steps.

In FSDP, the gradients are sharded: each rank is responsible for only a fraction ($1/G$) of the total gradients. The challenge lies in the lifecycle of the gradient. During the backward pass, FSDP computes gradients for a layer. Immediately after computation, a ReduceScatter operation aggregates these gradients across ranks and shards them. The full-sized gradients are then discarded to free memory.

If we blindly apply standard accumulation logic, we face a trade-off between communication efficiency and memory usage.

## The `no_sync` Trap

PyTorch provides a `model.no_sync()` context manager, which is the standard tool for gradient accumulation in DDP. It prevents the communication hook from firing during the backward pass.

However, using `no_sync()` with FSDP fundamentally alters the memory profile. If communication is disabled, FSDP cannot perform the ReduceScatter operation. Consequently, each rank must hold the unsharded gradients for the entire model until the synchronization occurs.

For a 70B-parameter model using Float16, the gradients alone require 140GB of memory. If `no_sync()` is active, every GPU attempts to allocate this 140GB, likely causing an immediate OOM. Therefore, for large models where FSDP is necessary, we typically do not use `no_sync()`; instead, we allow the ReduceScatter to happen at every micro-step.
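For contrast, the following is a minimal sketch of the DDP-style accumulation pattern built on `no_sync()`, i.e., the pattern we avoid for large FSDP models. It assumes that `model`, `optimizer`, `loss_fn`, `dataloader`, and `accumulation_steps` are defined as in the training loop shown later in this section.

```python
from contextlib import nullcontext

for step, (inputs, labels) in enumerate(dataloader):
    last_in_cycle = (step + 1) % accumulation_steps == 0

    # Skip gradient synchronization on all but the last micro-batch of the
    # cycle. With FSDP this keeps full unsharded gradients alive on every
    # rank until the final backward, which is exactly the memory spike
    # described above.
    sync_context = nullcontext() if last_in_cycle else model.no_sync()
    with sync_context:
        loss = loss_fn(model(inputs), labels) / accumulation_steps
        loss.backward()

    if last_in_cycle:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```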
For large models trained with FSDP, this means:

- **Memory efficiency:** We accept the communication overhead of a ReduceScatter on every micro-batch in exchange for never materializing full gradients.
- **Accumulation location:** The accumulation happens on the sharded gradients, which are $1/G$ the size of the full model gradients.

## Internal Logic of Sharded Accumulation

When not using `no_sync()`, the training loop proceeds as follows:

1. **Forward pass:** Parameters are gathered (AllGather), computation occurs, and the gathered parameters are freed.
2. **Backward pass:** Parameters are gathered again and gradients are computed.
3. **Reduce-Scatter:** Gradients are synchronized and sharded immediately.
4. **Accumulation:** The resulting sharded gradient is added to the `.grad` attribute of the local parameter shard.

PyTorch handles the addition to `.grad` automatically. If `.grad` is not set to `None` (via `zero_grad()`), the autograd engine adds the newly computed gradients to the existing values.

The following diagram illustrates the memory footprint comparison between DDP-style accumulation (holding full gradients) and FSDP-style accumulation (accumulating sharded gradients).

```dot
digraph G {
    rankdir=TB;
    node [shape=box, style=filled, fontname="Helvetica"];

    subgraph cluster_ddp {
        label="DDP / FSDP with no_sync()";
        style=dashed;
        color="#adb5bd";
        step1 [label="Micro-Batch 1 Backward", fillcolor="#eebefa"];
        mem1  [label="Alloc: Full Unsharded Grads", fillcolor="#ffc9c9"];
        step2 [label="Micro-Batch 2 Backward", fillcolor="#eebefa"];
        mem2  [label="Alloc: Full Unsharded Grads\n(Accumulated)", fillcolor="#ffc9c9"];
        comm  [label="Communication (AllReduce/ReduceScatter)", fillcolor="#91a7ff"];
        step1 -> mem1 -> step2 -> mem2 -> comm;
    }

    subgraph cluster_fsdp {
        label="FSDP Standard Accumulation";
        style=dashed;
        color="#adb5bd";
        f_step1 [label="Micro-Batch 1 Backward", fillcolor="#b2f2bb"];
        f_comm1 [label="ReduceScatter", fillcolor="#91a7ff"];
        f_mem1  [label="Store: Sharded Grads (1/G)", fillcolor="#d8f5a2"];
        f_step2 [label="Micro-Batch 2 Backward", fillcolor="#b2f2bb"];
        f_comm2 [label="ReduceScatter", fillcolor="#91a7ff"];
        f_mem2  [label="Accumulate into Sharded Grads", fillcolor="#d8f5a2"];
        f_step1 -> f_comm1 -> f_mem1 -> f_step2 -> f_comm2 -> f_mem2;
    }
}
```

Comparison of memory states during accumulation. The top path shows the memory spike risk when delaying communication. The bottom path shows the memory-safe FSDP approach, where communication occurs per micro-batch to maintain the sharded state.

## Implementation Patterns

To implement gradient accumulation with FSDP, we control the optimizer step and the gradient zeroing manually. No specialized context manager is needed; we simply rely on the PyTorch autograd engine to accumulate into the leaf tensors (the sharded parameters).

The following implementation demonstrates a loop for FSDP training. Note the normalization of the loss: since the optimizer step happens once per $N_{acc}$ micro-batches, the loss must be scaled by $1/N_{acc}$ to prevent the effective learning rate from scaling with the number of accumulation steps.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hyperparameters
accumulation_steps = 4

model = FSDP(base_model, ...)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training Loop
model.train()
optimizer.zero_grad(set_to_none=True)

for step, (inputs, labels) in enumerate(dataloader):
    # 1. Forward Pass
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)

    # 2. Scale Loss for Accumulation
    # This keeps the gradient magnitude consistent regardless of the
    # number of accumulation steps.
    loss = loss / accumulation_steps

    # 3. Backward Pass
    # FSDP handles the ReduceScatter here automatically; gradients are
    # accumulated into the parameters' .grad, which are already sharded.
    loss.backward()

    # 4. Conditional Step
    if (step + 1) % accumulation_steps == 0:
        # Optional: Gradient Clipping
        # FSDP's clip_grad_norm_ handles clipping on sharded gradients correctly.
        model.clip_grad_norm_(1.0)

        # Update weights
        optimizer.step()

        # Clear gradients for the next accumulation cycle
        optimizer.zero_grad(set_to_none=True)
```
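The conditional `zero_grad(set_to_none=True)` at the end of the cycle is what drives the accumulation: between those calls, every `backward()` adds its result into the existing (sharded) `.grad` tensors instead of overwriting them. A minimal single-tensor sketch of this default autograd behavior, with no FSDP involved:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))

(2 * p).sum().backward()
print(p.grad)  # tensor([2., 2., 2.])

# Without zero_grad(), a second backward() accumulates into the existing .grad.
(2 * p).sum().backward()
print(p.grad)  # tensor([4., 4., 4.])
```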
## Throughput Implications

While this approach solves the memory constraint, it introduces a performance characteristic specific to FSDP. In DDP, gradient accumulation improves throughput by reducing the frequency of expensive network calls. In FSDP (without `no_sync`), we perform a ReduceScatter on every micro-batch, so we do not see the same reduction in communication overhead.

The primary benefit of gradient accumulation in FSDP is memory viability, not communication hiding. It allows you to train with a target global batch size (e.g., 4 million tokens) even when the hardware can only support a local micro-batch of 1 or 2 sequences due to the memory cost of activations and temporary buffers.

However, there is a subtle efficiency gain. By performing the optimizer update less frequently, we reduce the overhead of the kernel launches associated with the optimizer step itself (e.g., the AdamW update logic), and we update the optimizer state shards less often. For extremely large models where the optimizer step is non-trivial, this provides a measurable throughput increase.

## Gradient Accumulation with Hybrid Sharding (HSDP)

When using Hybrid Sharded Data Parallel (HSDP), where parameters are sharded within a node but replicated across nodes, we regain some flexibility. HSDP performs ReduceScatter within the node and AllReduce across nodes.

In this configuration, one might consider accumulating gradients locally before the inter-node AllReduce. However, PyTorch's current FSDP implementation tightly couples the intra-node and inter-node communications. The standard recommendation remains to let the communication hooks fire on every micro-batch so that memory stays consistently sharded.

Optimizing the accumulation phase therefore often involves overlapping the ReduceScatter of the current layer with the computation of the previous layer (backward prefetching), which is configured via the `backward_prefetch` policy discussed in the next chapter. This overlap is critical precisely because accumulation does not reduce the total number of communication calls.
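As a point of reference, the sketch below shows how these options are selected when wrapping the model. It assumes `base_model` is defined as in the earlier example and relies on the default process-group layout for hybrid sharding (shard within a node, replicate across nodes); the prefetch policy itself is covered in the next chapter.

```python
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

model = FSDP(
    base_model,
    # HSDP: shard parameters within a node and replicate across nodes, so the
    # per-micro-batch gradient sync is an intra-node ReduceScatter followed by
    # an inter-node AllReduce.
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    # Controls how backward-pass communication overlaps with computation;
    # discussed in detail in the next chapter.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
```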