Scaling memory capacity via sharding introduces a necessary trade-off: increased network utilization. While ZeRO stages effectively dismantle the memory wall by partitioning states across devices, they fundamentally alter the communication pattern of the training loop. In a standard Distributed Data Parallel (DDP) setup, communication is confined exclusively to the backward pass where gradients are synchronized. In contrast, Fully Sharded Data Parallel (FSDP) requires active communication during both the forward and backward passes to materialize parameters on demand.
Understanding the precise volume of data movement is critical for designing cluster topologies and debugging performance regressions. If the network interconnect cannot sustain the required throughput, the GPU compute units stall, idling while they wait for parameter shards to arrive.
To analyze the bandwidth requirements, we must first identify the specific collective communication primitives used by ZeRO. Unlike DDP, which primarily relies on AllReduce, FSDP utilizes AllGather and ReduceScatter.
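To make these primitives concrete, the short sketch below calls them directly through torch.distributed. It is a minimal illustration, assuming a process group has already been initialized (for example via torchrun) and using an arbitrary one-million-element tensor in place of real parameters or gradients.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group has already run (e.g. launched with torchrun)
# and that each rank has selected its own CUDA device.
world_size = dist.get_world_size()
shard = torch.randn(1_000_000 // world_size, dtype=torch.bfloat16, device="cuda")

# AllGather: every rank receives the concatenation of all shards. This is the
# primitive FSDP uses to materialize full parameters on demand.
full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device="cuda")
dist.all_gather_into_tensor(full, shard)

# ReduceScatter: values are summed across ranks and each rank keeps only the
# slice matching its shard. FSDP uses this to produce sharded gradients.
grad_shard = torch.empty_like(shard)
dist.reduce_scatter_tensor(grad_shard, full, op=dist.ReduceOp.SUM)

# AllReduce: the single collective that classic DDP relies on for gradient sync.
dist.all_reduce(full, op=dist.ReduceOp.SUM)
```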
Let Ψ represent the total number of model parameters. We assume mixed-precision training where parameters and gradients are stored in 16-bit formats (FP16 or BF16), meaning each element occupies 2 bytes.
In a cluster with N GPUs, a standard DDP implementation performs an AllReduce operation on the gradients once per step. Using ring-based or tree-based algorithms, the communication volume for AllReduce is 2Ψ. This effectively means every parameter element traverses the network twice (once for reduction, once for broadcast) per training step, independent of N (for large N).
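As a back-of-envelope check, using an illustrative 7B-parameter model:

```python
# DDP gradient AllReduce traffic for a hypothetical 7B-parameter model.
psi = 7e9                    # total model parameters
bytes_per_elem = 2           # FP16/BF16 gradients
allreduce_bytes = 2 * psi * bytes_per_elem   # ~2 * Psi elements cross the wire per GPU
print(f"DDP moves ~{allreduce_bytes / 1e9:.0f} GB per GPU per step")   # -> ~28 GB
```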
FSDP changes this equation. Since parameters are sharded, they are not locally available. The training step follows this sequence:
1. Forward pass: an AllGather is triggered to collect the parameters. Each rank downloads (N−1)/N · Ψ of data.
2. Backward pass: the parameters are AllGathered again to compute gradients with respect to the inputs, moving another (N−1)/N · Ψ.
3. Gradient synchronization: a ReduceScatter reduces the full gradients and leaves each rank holding only its 1/N shard, moving a further (N−1)/N · Ψ.

The total data movement for FSDP (specifically ZeRO-3) per step is the sum of these three operations. For large N, the cost of each AllGather and of the ReduceScatter approaches Ψ, for a total of approximately 3Ψ.
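This accounting can be written out directly. The snippet below is a simplified traffic model under assumed conditions, not FSDP internals: it treats the model as a stack of equally sized blocks, uses 2-byte elements, and tallies what each rank sends or receives across the three collectives.

```python
# Simplified per-GPU traffic model for one ZeRO-3 / FSDP training step.
def fsdp_step_traffic_bytes(block_sizes, n_gpus, bytes_per_elem=2):
    scale = (n_gpus - 1) / n_gpus          # ring-collective factor (N-1)/N
    traffic = 0.0
    # Forward pass: AllGather each block's parameters just before it runs.
    for psi_block in block_sizes:
        traffic += scale * psi_block * bytes_per_elem
    # Backward pass: AllGather the parameters again, then ReduceScatter gradients.
    for psi_block in reversed(block_sizes):
        traffic += scale * psi_block * bytes_per_elem   # second AllGather
        traffic += scale * psi_block * bytes_per_elem   # ReduceScatter of gradients
    return traffic

# 32 identical blocks totalling ~70B parameters, sharded across 64 GPUs:
# ~413 GB, approaching the 3 * Psi * 2-byte limit of 420 GB as N grows.
print(f"~{fsdp_step_traffic_bytes([70e9 / 32] * 32, 64) / 1e9:.0f} GB per GPU per step")
```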
This analysis reveals a significant architectural implication: FSDP moves approximately 1.5x the communication volume of DDP (3Ψ vs 2Ψ), and therefore needs roughly 1.5x the bandwidth to sustain the same step time.
The data flow demonstrates the ephemeral nature of full parameters in FSDP. Unlike DDP, where parameters persist, FSDP requires network traversal to materialize weights for both forward and backward passes, followed by a scatter operation for gradients.
The formula V_FSDP ≈ 3Ψ is an approximation that assumes large N. The precise communication cost for a ring-based collective is 2(N−1)/N · Ψ for AllReduce and (N−1)/N · Ψ for AllGather or ReduceScatter.
As the cluster size N increases, the factor (N−1)/N rapidly approaches 1. This indicates that adding more GPUs does not reduce the per-GPU communication volume. Instead, it keeps the per-GPU throughput requirement constant while increasing the aggregate bandwidth demanded of the cluster.
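A quick sweep over illustrative GPU counts (using the 70B-parameter case discussed below) shows how quickly the factor saturates:

```python
# Per-GPU FSDP traffic as the cluster grows: the (N-1)/N factor saturates quickly.
psi_bytes = 70e9 * 2                          # 70B parameters at 2 bytes each
for n in (2, 8, 64, 512):
    factor = (n - 1) / n
    print(f"N={n:4d}: {3 * factor * psi_bytes / 1e9:6.1f} GB per GPU per step")
# N=   2:  210.0 GB ... N= 512:  419.2 GB (limit: 3 * Psi * 2 bytes = 420 GB)
```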
This constant per-device bandwidth pressure makes FSDP sensitive to "stragglers" or slow nodes. If a single link in the cluster negotiates a lower speed (e.g., a fallback from InfiniBand HDR to EDR, or a TCP retransmission issue), the AllGather collective for the entire group operates at the speed of that slowest link.
The following chart compares the theoretical minimum data transfer required per training step for various model sizes. Note the divergence between DDP and FSDP as model size grows.
Comparison of network volume for 16-bit mixed precision training. As model size scales to 70B parameters, FSDP requires moving 420GB of data across the network interconnect for every single optimizer step, strictly for parameter and gradient synchronization.
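The chart's theoretical minimums follow directly from the large-N approximations; the 7B and 13B points below are illustrative sizes alongside the 70B case cited in the caption.

```python
# Theoretical minimum per-GPU traffic per step (large-N limit, 2-byte elements).
for psi in (7e9, 13e9, 70e9):
    ddp_gb = 2 * psi * 2 / 1e9       # gradient AllReduce
    fsdp_gb = 3 * psi * 2 / 1e9      # two AllGathers + one ReduceScatter
    print(f"{psi / 1e9:4.0f}B params: DDP ~{ddp_gb:4.0f} GB   FSDP ~{fsdp_gb:4.0f} GB")
# 70B params: DDP ~280 GB vs FSDP ~420 GB -- the divergence visible in the chart.
```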
While bandwidth (GB/s) determines the transfer time for large tensors, latency (microseconds) dictates performance for small tensors. A naive implementation of FSDP could attempt to shard every single linear layer individually. For a Transformer block containing distinct Key, Query, and Value projection layers, this would trigger three separate small AllGather operations.
The overhead of initiating a collective communication kernel often outweighs the transfer time for small payloads. To mitigate this, PyTorch FSDP aggregates parameters into "FlatParameters." This mechanism flattens multiple small tensors within a module (like a Transformer Decoder Layer) into a single contiguous block of memory.
This aggregation serves two purposes:
1. It allows each AllGather to operate on larger chunks of data, better utilizing the interconnect bandwidth.
2. It reduces the number of collective kernel launches, amortizing the fixed per-call latency described above.
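A minimal configuration sketch is shown below, assuming a process group has already been initialized (for example via torchrun); the toy nn.TransformerDecoder and its hyperparameters are illustrative. Wrapping at the decoder-layer boundary groups each layer's attention and MLP weights into one FlatParameter, so the layer is fetched with a single AllGather.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes dist.init_process_group has already run (e.g. launched with torchrun)
# and the current rank has selected its CUDA device.

# Toy decoder-only stack; d_model, nhead, and num_layers are illustrative values.
model = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=1024, nhead=16, batch_first=True),
    num_layers=24,
).cuda()

# Wrap at the decoder-layer boundary: each layer's attention and MLP weights
# become one FlatParameter, materialized by a single AllGather.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerDecoderLayer},
)

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 style full sharding
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,    # 2-byte parameters for the AllGathers
        reduce_dtype=torch.bfloat16,   # 2-byte gradients for the ReduceScatter
    ),
)
```

With this policy, each wrapped decoder layer issues one AllGather in the forward pass and one in the backward pass, rather than separate collectives for its Query, Key, Value, and MLP projections.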
Grouping parameters addresses latency, but the governing constraint at scale is the ratio of communication time to computation time. If the communication time for the 3Ψ volume exceeds the computation time of the forward and backward passes, training becomes communication-bound. This ratio is the primary efficiency metric when tuning large clusters. Efficient scaling requires that the network interconnect provide enough bandwidth such that:

T_comm ≤ T_compute

When this condition is met and overlap strategies (such as prefetching the next layer's AllGather during the current layer's computation) are correctly configured, the cost of the 3Ψ data movement can be effectively hidden behind the arithmetic intensity of the matrix multiplications.
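A rough way to check this condition is to compare the two times with the common ~6 FLOPs-per-parameter-per-token estimate for a dense Transformer forward-plus-backward pass. All example numbers below (tokens per GPU, sustained throughput, effective collective bandwidth) are assumptions that should be replaced with measured values.

```python
def overlap_feasible(psi, tokens_per_gpu, sustained_flops, net_bandwidth_bytes):
    """Back-of-envelope check of T_comm <= T_compute for ZeRO-3 / FSDP."""
    # Dense Transformer: roughly 6 FLOPs per parameter per token (forward + backward).
    t_compute = 6 * psi * tokens_per_gpu / sustained_flops
    # 3 * Psi elements per step at 2 bytes each (two AllGathers + one ReduceScatter).
    t_comm = 3 * psi * 2 / net_bandwidth_bytes
    return t_comm <= t_compute, t_comm, t_compute

# Example: 70B model, 16k tokens per GPU per step, ~400 TFLOP/s sustained,
# ~50 GB/s effective per-GPU collective bandwidth (all illustrative numbers).
ok, t_comm, t_compute = overlap_feasible(70e9, 16_384, 400e12, 50e9)
print(f"comm {t_comm:.1f}s vs compute {t_compute:.1f}s -> "
      f"{'can be hidden' if ok else 'communication-bound'}")
```

In practice, both terms should come from profiled measurements rather than peak datasheet numbers, since achieved FLOP utilization and effective collective bandwidth are usually well below their theoretical maxima.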