The fundamental limitation of Distributed Data Parallel (DDP) is redundancy. In a DDP setup with N GPUs, the system maintains N identical copies of the model parameters, gradients, and optimizer states. While this allows for parallel computation of the backward pass, it creates a memory wall: the maximum model size is strictly capped by the VRAM of a single GPU, regardless of the total cluster capacity.
To break this barrier, we employ the Zero Redundancy Optimizer (ZeRO) family of strategies. ZeRO eliminates this redundancy by partitioning the model state across data parallel processes. Instead of replicating the full state, each device owns a distinct shard of it. The optimization occurs in three progressive stages, each trading increased communication complexity for substantial memory savings.
Before partitioning, we must quantify what consumes GPU memory. For a model with Ψ parameters trained with mixed precision (FP16/BF16) and the Adam optimizer, the memory footprint is dominated by three components:

- Parameters (P): the FP16/BF16 weights, 2Ψ bytes.
- Gradients (G): the FP16/BF16 gradients, 2Ψ bytes.
- Optimizer states (OS): the FP32 master weights plus Adam's momentum and variance estimates, 12Ψ bytes.
In a standard DDP configuration, every GPU holds all 16Ψ bytes. ZeRO targets these components sequentially.
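To make the 16Ψ figure concrete, the following sketch (a hypothetical helper, assuming the 2Ψ/2Ψ/12Ψ breakdown above) computes the per-GPU state footprint for a given parameter count:

```python
def ddp_state_bytes(num_params: int) -> int:
    """Per-GPU model-state footprint under plain DDP with mixed-precision Adam."""
    fp16_params = 2 * num_params        # FP16/BF16 weights
    fp16_grads = 2 * num_params         # FP16/BF16 gradients
    optimizer_states = 12 * num_params  # FP32 master weights + Adam momentum + variance
    return fp16_params + fp16_grads + optimizer_states

# Example: a 7-billion-parameter model needs ~112 GB of state on *every* GPU,
# before counting activations, which already exceeds an 80 GB accelerator.
print(ddp_state_bytes(7_000_000_000) / 1e9)  # 112.0 (GB)
```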
Stage 1 ($P_{os}$) targets the largest memory consumer: the optimizer states. In this configuration, the parameters (P) and gradients (G) remain replicated across all devices, preserving the communication pattern of DDP for the forward and backward passes. However, the optimizer step is sharded.
With $N_d$ devices, the optimizer state is split into $N_d$ equal partitions, and each device updates only the shard of the parameters corresponding to its partition. At the end of the step, an AllGather operation synchronizes the updated parameters across all devices.
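PyTorch exposes this Stage 1 behavior through torch.distributed.optim.ZeroRedundancyOptimizer, which wraps a standard optimizer and shards its state across the data-parallel group. A minimal sketch, assuming the process group is already initialized and that model, local_rank, and batch are defined elsewhere:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(...) has been called and `model` already
# lives on the device for this rank (`local_rank` and `batch` are placeholders).
model = DDP(model, device_ids=[local_rank])

# Each rank holds only its 1/N_d shard of the Adam states; parameters and
# gradients stay fully replicated, matching ZeRO Stage 1.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=1e-4,
)

loss = model(batch).mean()  # placeholder forward pass
loss.backward()             # gradients are synchronized by DDP as usual
optimizer.step()            # each rank updates its shard, then the updated
                            # parameters are re-synchronized across ranks
```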
Memory consumption per device drops from 2Ψ+2Ψ+12Ψ to approximately:
$$\text{Mem}_{\text{Stage 1}} = 2\Psi + 2\Psi + \frac{12\Psi}{N_d}$$

For large clusters, this reduces memory usage by nearly 75% compared to DDP, as the optimizer state term approaches zero.
Stage 2 ($P_{os+g}$) extends sharding to the gradients. In standard DDP, gradients are computed locally and then synchronized using an AllReduce operation. AllReduce is logically equivalent to a ReduceScatter followed by an AllGather.
ZeRO Stage 2 modifies this flow. After the backward pass, the system performs a ReduceScatter operation. Each GPU receives and aggregates only the gradients corresponding to the partition of the parameters it is responsible for updating. It then discards the rest.
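At the collective level, this step corresponds to dist.reduce_scatter: each rank receives only the summed slice of the gradient it owns, rather than the full reduced tensor. A minimal sketch, assuming the gradient has been flattened into a single tensor that divides evenly by the world size:

```python
import torch
import torch.distributed as dist

def shard_gradients(flat_grad: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter a flattened gradient so this rank keeps only its shard.

    Assumes dist.init_process_group(...) has been called and that flat_grad
    has the same shape on every rank and divides evenly by the world size.
    """
    world_size = dist.get_world_size()
    chunks = [c.contiguous() for c in flat_grad.chunk(world_size)]  # one slice per rank
    local_shard = torch.empty_like(chunks[0])
    # Each rank receives the element-wise sum of its own slice across all ranks;
    # the other slices are never accumulated locally and can be discarded.
    dist.reduce_scatter(local_shard, chunks, op=dist.ReduceOp.SUM)
    return local_shard
```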
Because the optimizer states are already sharded (from Stage 1), each GPU now has exactly what it needs to update its specific parameter shard: the specific optimizer state and the specific accumulated gradients.
Memory consumption becomes:
$$\text{Mem}_{\text{Stage 2}} = 2\Psi + \frac{2\Psi}{N_d} + \frac{12\Psi}{N_d}$$

This stage yields significant gains with minimal communication overhead, as ReduceScatter is a primitive already inherent in the AllReduce operation used by DDP.
Stage 3 ($P_{os+g+p}$) is the core mechanism behind what is colloquially termed "Full" FSDP. In this stage, the model parameters themselves are sharded. No single GPU holds the complete model weights at rest.
This introduces a new challenge: computing the forward and backward passes requires the full weights for the specific layers being processed. ZeRO-3 solves this through temporal materialization.
- Forward pass: Before computing a layer, the device performs an AllGather to fetch the missing parameter shards from other GPUs. The layer computes, and the full parameters are immediately freed (discarded) to save memory.
- Backward pass: The device again AllGathers the full parameters to compute gradients, then discards them.

Training a model larger than the aggregate memory of the entire cluster is still not possible, but this approach allows training models that are as large as the sum of all GPU memory, minus activation overheads.
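The gather/compute/free cycle can be illustrated with a deliberately simplified forward pass over a stack of linear layers. This is a conceptual sketch only (the shard layout, shapes, and function name are assumptions, and real FSDP overlaps these collectives with computation):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def sharded_forward(shards, shapes, x):
    """Conceptual ZeRO-3 forward pass, not the FSDP internals.

    shards[i]: this rank's 1/N_d slice of layer i's flattened weight.
    shapes[i]: the full (out_features, in_features) of layer i.
    Assumes dist.init_process_group(...) has been called and shapes divide evenly.
    """
    world_size = dist.get_world_size()
    for shard, (out_f, in_f) in zip(shards, shapes):
        # Temporal materialization: gather the full weight just before it is needed.
        gathered = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(gathered, shard)
        full_weight = torch.cat(gathered).view(out_f, in_f)
        x = F.linear(x, full_weight)
        del gathered, full_weight  # free the full copy immediately; only the shard persists
    return x
```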
The memory per device is reduced to the theoretical minimum:
$$\text{Mem}_{\text{Stage 3}} = \frac{2\Psi + 2\Psi + 12\Psi}{N_d} = \frac{16\Psi}{N_d}$$

The following visualization demonstrates the memory distribution across 4 devices under different strategies. Note how Stage 3 (FSDP) distributes the entire stack evenly.
Comparison of state allocation between DDP and ZeRO-3 (FSDP) on a 4-GPU cluster. DDP replicates the full state; FSDP shards all components.
The choice of stage dramatically shifts the maximum trainable model size. While Stage 1 and 2 offer substantial reductions, Stage 3 enables linear scaling with the number of GPUs.
The chart below illustrates the memory consumption per GPU for a theoretical 10-billion parameter model (requiring approximately 160 GB of total state) across an 8-GPU cluster.
Memory usage breakdown per GPU. Note that ZeRO-1 provides the largest single drop in memory usage by sharding the optimizer state, while ZeRO-3 minimizes the footprint of all components.
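The per-GPU values in the chart can be reproduced directly from the stage formulas. A small sketch (hypothetical helper name) for the 10-billion-parameter, 8-GPU example:

```python
def per_gpu_state_gb(num_params: float, n_devices: int, stage: int) -> float:
    """Per-GPU model-state memory in GB, using 2/2/12 bytes per parameter."""
    params, grads, opt = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:           # ZeRO-1: shard optimizer states
        opt /= n_devices
    if stage >= 2:           # ZeRO-2: also shard gradients
        grads /= n_devices
    if stage >= 3:           # ZeRO-3: also shard parameters
        params /= n_devices
    return (params + grads + opt) / 1e9

psi, n = 10e9, 8
for stage in (0, 1, 2, 3):   # stage 0 = plain DDP
    print(f"Stage {stage}: {per_gpu_state_gb(psi, n, stage):.1f} GB")
# Stage 0: 160.0 GB, Stage 1: 55.0 GB, Stage 2: 37.5 GB, Stage 3: 20.0 GB
```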
These memory benefits are not free; the currency is network bandwidth.
- DDP, ZeRO-1, and ZeRO-2: AllReduce on gradients. Communication volume is 2Ψ per step (send + receive).
- ZeRO-3: AllGather on parameters (forward pass), AllGather on parameters (backward pass), and ReduceScatter on gradients. The total communication volume increases to approximately 3Ψ.

In bandwidth-restricted environments (like standard Ethernet), the extra latency of fetching parameters on demand in Stage 3 can throttle compute throughput. This makes the configuration of high-speed interconnects like NVLink or InfiniBand essential for Stage 3 training, a topic we will address in the Multi-Node Networking chapter.
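To get a feel for the cost, the rough estimate below converts communication volume into a lower bound on transfer time per step. The bandwidth figures are illustrative assumptions, not measurements, and real systems overlap communication with computation:

```python
def comm_seconds(num_params: float, volume_factor: float,
                 bytes_per_elem: int, bandwidth_gb_s: float) -> float:
    """Naive per-step communication time: ignores overlap, latency, and topology."""
    total_bytes = volume_factor * num_params * bytes_per_elem
    return total_bytes / (bandwidth_gb_s * 1e9)

psi = 10e9  # 10B parameters, FP16 elements (2 bytes each)
for name, factor in [("DDP / ZeRO-1/2", 2.0), ("ZeRO-3", 3.0)]:
    # Assumed link speeds: ~12.5 GB/s for 100 Gb Ethernet, ~300 GB/s for NVLink-class links.
    for link, gb_s in [("100 GbE", 12.5), ("NVLink", 300.0)]:
        print(f"{name:<14} over {link}: {comm_seconds(psi, factor, 2, gb_s):.2f} s/step")
```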
In PyTorch FSDP, these strategies are not rigid, mutually exclusive modes; they are selected via the sharding_strategy parameter.
- ShardingStrategy.FULL_SHARD maps to ZeRO-3.
- ShardingStrategy.SHARD_GRAD_OP maps to ZeRO-2.
- ShardingStrategy.NO_SHARD behaves like DDP.

Understanding these stages allows you to select the correct strategy based on your hardware constraints. If your model fits in memory with ZeRO-2, it is often preferred over ZeRO-3 due to lower communication overhead. However, for the terabyte-scale models that define modern AI, ZeRO-3 is often the only viable path forward.
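A minimal configuration sketch, assuming the process group is already initialized and model lives on the correct device for this rank:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# FULL_SHARD corresponds to ZeRO-3; swap in SHARD_GRAD_OP for ZeRO-2
# or NO_SHARD for DDP-like behavior.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```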