While standard data parallelism, as implemented in PyTorch's DistributedDataParallel (DDP), is effective for scaling throughput, it introduces a significant memory bottleneck. Each GPU in the data-parallel group maintains a full replica of the model's parameters, gradients, and optimizer states. For large models, especially those using optimizers like Adam which store momentum and variance, the memory required for optimizer states alone can exceed the model's parameter size by a factor of two or more. This replication becomes the limiting factor long before the compute capacity of the GPU is saturated.
Microsoft's DeepSpeed library directly addresses this memory inefficiency with its Zero Redundancy Optimizer, or ZeRO. Instead of replicating the entire training state, ZeRO partitions it across the available data-parallel devices. This allows you to train models that are orders of magnitude larger than what would fit on a single GPU, without resorting to the complexities of pure model or pipeline parallelism.
To fully appreciate what ZeRO accomplishes, let's break down the memory consumption on a single GPU during standard data-parallel training. The total memory, M, can be modeled as the sum of three primary components:
M = M_P + M_G + M_O

where:

- M_P is the memory for the model parameters
- M_G is the memory for the gradients
- M_O is the memory for the optimizer states (for Adam, the momentum and variance estimates, plus, in mixed-precision training, an fp32 master copy of the weights)
In a standard DDP setup with K GPUs, each GPU holds a full copy of all three components, so every byte of training state is replicated K times across the cluster.
Memory layout in standard data parallelism, where each GPU holds a complete replica of the model parameters, gradients, and optimizer states.
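To put numbers on this, a rough sketch like the one below is useful. It assumes mixed-precision training with Adam, where each parameter costs roughly 2 bytes for the fp16 weight, 2 bytes for the fp16 gradient, and 12 bytes for the fp32 optimizer state (master weight, momentum, and variance); these byte counts are illustrative assumptions, and the exact figures depend on your dtype and optimizer choices.

# Rough per-GPU memory estimate for standard DDP with mixed-precision Adam.
# Assumed costs per parameter: 2 B fp16 weight (M_P), 2 B fp16 gradient (M_G),
# 12 B fp32 optimizer state (master weight + momentum + variance, M_O).
def ddp_memory_gib(num_params: float) -> float:
    bytes_per_param = 2 + 2 + 12  # M_P + M_G + M_O
    return num_params * bytes_per_param / 1024**3

# Every GPU in the data-parallel group pays this cost, no matter how many GPUs you add.
print(f"{ddp_memory_gib(7.5e9):.0f} GiB per GPU")  # ~112 GiB for a 7.5B-parameter model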
ZeRO systematically eliminates this redundancy by sharding the model and optimizer states across the GPUs. It is implemented in three progressive stages, allowing you to choose the level of optimization that best fits your needs.
ZeRO Stage 1 (ZeRO-1) targets the largest source of memory redundancy for optimizers like Adam: the optimizer states. ZeRO-1 partitions these states across the data-parallel processes, and each GPU updates only its assigned partition of the parameters. After the optimizer step, an all-gather operation collects the updated parameter shards so that every GPU enters the next forward pass with a complete, up-to-date copy of the model.
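Enabling Stage 1 is a one-line change in the configuration. As a minimal sketch, the fragment below passes the configuration to deepspeed.initialize as a Python dict rather than a JSON file (both forms are accepted); the batch sizes are placeholder values.

# Minimal ZeRO Stage 1 configuration, passed as a dict instead of a JSON file.
# Only the optimizer states are partitioned; parameters and gradients stay replicated.
zero1_config = {
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 1},
    "fp16": {"enabled": True},
}

# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, optimizer=optimizer, config=zero1_config)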
ZeRO-2 builds on Stage 1 by also partitioning the gradients. During the backward pass, instead of each GPU holding a full set of gradients and averaging them with an all-reduce operation, a reduce-scatter operation is used: it averages the gradients and scatters the resulting shards to the GPUs that own the corresponding optimizer-state partitions. Each GPU therefore keeps only the gradient shard it actually needs for its parameter update, rather than a full copy of all gradients.
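To make the communication pattern concrete, the sketch below uses torch.distributed directly rather than DeepSpeed's internals: a reduce-scatter sums a flattened gradient across ranks and leaves each rank holding only its own shard, which is then scaled to obtain the average. The tensor size and the sharded_grad_average helper are purely illustrative.

import os
import torch
import torch.distributed as dist

def sharded_grad_average(full_grad: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter a flattened gradient: each rank keeps only its averaged shard."""
    world_size = dist.get_world_size()
    shard = torch.empty(full_grad.numel() // world_size,
                        dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(shard, full_grad, op=dist.ReduceOp.SUM)
    return shard / world_size  # turn the element-wise sum into an average

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=<num_gpus> reduce_scatter_demo.py
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # Stand-in for a model's flattened gradient (length divisible by the world size).
    grad = torch.full((1024,), float(dist.get_rank() + 1), device="cuda")
    shard = sharded_grad_average(grad)
    print(f"rank {dist.get_rank()} holds an averaged shard of {shard.numel()} elements")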
ZeRO-3 is the most advanced stage and the one that enables training truly massive models. It partitions everything: optimizer states, gradients, and the model parameters themselves, so each GPU holds only a slice of the entire model.
During the forward and backward passes, ZeRO-3 reconstructs each layer's full parameters on every GPU only when they are needed for computation. A layer's parameters are gathered from the other GPUs right before it executes, and the memory is released immediately afterward. As a result, the peak parameter memory at any moment is proportional to the size of the largest layer, not the entire model.
Memory layout with ZeRO-3, where parameters (P), gradients (G), and optimizer states (O) are all partitioned across the available GPUs.
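Extending the earlier back-of-the-envelope accounting shows how each stage shrinks the per-GPU footprint: Stage 1 divides the optimizer-state term by the number of data-parallel GPUs, Stage 2 also divides the gradient term, and Stage 3 divides all three. The sketch below reuses the assumed 2/2/12 bytes-per-parameter split and ignores activations, temporary buffers, and communication buckets, so treat the numbers as rough lower bounds on model-state memory.

# Approximate per-GPU model-state memory (GiB) at each ZeRO stage.
# Assumed costs per parameter: 2 B params, 2 B grads, 12 B optimizer state.
def zero_memory_gib(num_params: float, num_gpus: int, stage: int) -> float:
    p, g, o = 2.0, 2.0, 12.0
    if stage >= 1:
        o /= num_gpus   # ZeRO-1: shard optimizer states
    if stage >= 2:
        g /= num_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        p /= num_gpus   # ZeRO-3: also shard parameters
    return num_params * (p + g + o) / 1024**3

for stage in (0, 1, 2, 3):
    print(f"stage {stage}: {zero_memory_gib(7.5e9, 64, stage):.1f} GiB per GPU")
# With 64 GPUs and a 7.5B-parameter model: ~111.8, ~29.2, ~15.5, ~1.7 GiB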
Even with ZeRO-3, the aggregate size of a massive model's state can exceed the total available GPU memory across your cluster. ZeRO-Offload extends the partitioning hierarchy by moving certain components of the training state to more abundant, albeit slower, memory tiers.
This offloading capability democratizes large model training, making it feasible on systems with limited GPU VRAM but large amounts of system RAM or fast storage.
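As a sketch of what this looks like in practice, the configuration fragment below enables ZeRO-3 with the optimizer states offloaded to CPU memory and the parameters offloaded to NVMe storage. The nvme_path is a placeholder, and whether NVMe offload is worthwhile depends heavily on the storage bandwidth available.

# Illustrative ZeRO-3 configuration with offloading to CPU RAM and NVMe.
zero3_offload_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
    "fp16": {"enabled": True},
}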
Integrating DeepSpeed into a PyTorch training script is remarkably straightforward. The primary changes involve initializing the DeepSpeed engine and modifying the training loop to use DeepSpeed's methods.
First, you create a ds_config.json file. This file is the control panel for all DeepSpeed features.
Example ds_config.json for ZeRO-2 with CPU Offload:
{
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 2,
    "steps_per_print": 100,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 0.0001,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "fp16": {
        "enabled": true
    }
}
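DeepSpeed checks that train_batch_size equals train_micro_batch_size_per_gpu multiplied by the gradient accumulation steps and the number of GPUs; the values above correspond, for example, to 8 GPUs with no gradient accumulation (2 × 1 × 8 = 16).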
Next, you modify your training script. The key change is the call to deepspeed.initialize, which wraps your model and optimizer in a DeepSpeed engine.
import torch
import deepspeed

# Assume model, optimizer, and data_loader are already defined, for example:
# model = MyTransformerModel()
# optimizer = torch.optim.AdamW(model.parameters())

# DeepSpeed initialization: wraps the model and optimizer according to the JSON config
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config='ds_config.json'
)

# Training loop
for step, batch in enumerate(data_loader):
    # Move the batch to this process's GPU (to_device is a small user-defined
    # helper that sends every tensor in the batch to the given device)
    batch = to_device(batch, model_engine.local_rank)

    # Forward pass through the DeepSpeed engine
    loss = model_engine(batch)

    # Backward pass - use the model engine instead of loss.backward()
    model_engine.backward(loss)

    # Optimizer step - use the model engine instead of optimizer.step()
    model_engine.step()
Notice the changes in the training loop:
- deepspeed.initialize returns a model_engine, which replaces your original model for the main operations.
- The loss.backward() call is replaced by model_engine.backward(loss).
- The optimizer.step() call is replaced by model_engine.step().

DeepSpeed handles all the underlying complexity of sharding, communication, and offloading based on the configuration you provide. This clean API allows you to experiment with different ZeRO stages and offloading strategies simply by changing the JSON configuration, without altering your core training logic.