While standard Data Parallelism (DP) allows scaling training across multiple GPUs by replicating the model and processing different data shards on each device, it quickly runs into memory limitations. Each GPU still needs to hold a full copy of the model parameters, gradients, and optimizer states. For models with billions of parameters, this replicated state consumes significant GPU memory, often exceeding the capacity of even high-end accelerators. This is where the Zero Redundancy Optimizer (ZeRO) offered by DeepSpeed provides a significant advantage. ZeRO is designed to eliminate this memory redundancy inherent in standard DP, allowing you to train much larger models or use larger batch sizes within the same hardware constraints.
ZeRO achieves this by partitioning the model states (optimizer states, gradients, and parameters) across the data-parallel processes instead of replicating them. It comes in three progressive stages, each offering greater memory savings at the cost of potentially increased communication.
Before examining ZeRO, let's visualize the memory consumption per GPU in standard DP. For a model with Ψ parameters, a standard optimizer like Adam typically stores the parameters, gradients, first-order momentum, and second-order variance. With mixed-precision training, the model-state memory required per GPU is roughly 2Ψ + 2Ψ + 12Ψ = 16Ψ bytes: 2Ψ for the FP16 parameters, 2Ψ for the FP16 gradients, and 12Ψ for the FP32 optimizer states (the parameter copy, momentum, and variance).
Activations also consume memory, but the model states are often the dominant factor for large models. In standard DP, all these states are replicated on each GPU. ZeRO systematically reduces these replications.
Memory components replicated on each GPU in standard Data Parallelism.
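To make these numbers concrete, the short sketch below applies the 16Ψ-bytes-per-parameter accounting to a hypothetical 7-billion-parameter model trained on 8 GPUs. The partitioning rules are simplified for illustration and ignore activations, fragmentation, and temporary buffers.

def model_state_gb(num_params, num_gpus, stage=0):
    # Mixed-precision Adam accounting: FP16 params and grads, FP32 optimizer states.
    params = 2 * num_params          # FP16 parameters
    grads = 2 * num_params           # FP16 gradients
    optim = 12 * num_params          # FP32 parameter copy + momentum + variance
    if stage >= 1:
        optim /= num_gpus            # Stage 1 partitions optimizer states
    if stage >= 2:
        grads /= num_gpus            # Stage 2 also partitions gradients
    if stage >= 3:
        params /= num_gpus           # Stage 3 also partitions parameters
    return (params + grads + optim) / 1e9

for stage in range(4):
    print(f"7B model, 8 GPUs, ZeRO stage {stage}: "
          f"{model_state_gb(7e9, 8, stage):.1f} GB of model states per GPU")

Standard DP (stage 0 here) needs roughly 112 GB of model states per GPU for this model, which exceeds an 80 GB accelerator before any activations are counted; Stage 3 brings it down to about 14 GB.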
ZeRO Stage 1 tackles the first layer of redundancy: the optimizer states. Adam, for example, maintains momentum and variance buffers, which are often the same size as the model parameters themselves (or even larger if storing FP32 copies for mixed-precision training). Stage 1 partitions these optimizer states across the available data-parallel GPUs. Each GPU only holds a slice of the total optimizer state corresponding to its data-parallel rank.
During the optimizer step (optimizer.step()), gradients still need to be reduced across all GPUs (as in standard DP), but each GPU only updates the portion of the parameters for which it holds the corresponding optimizer state partition.
Benefits: Optimizer state memory is divided by the number of data-parallel GPUs, removing what is often the largest replicated component under mixed-precision Adam, while communication volume stays essentially the same as standard DP.
Configuration Example (DeepSpeed JSON):
{
  "zero_optimization": {
    "stage": 1
  },
  "fp16": {
    "enabled": true
  },
  "train_batch_size": 32,
  "gradient_accumulation_steps": 1
}
ZeRO Stage 2 goes further by partitioning both the optimizer states and the gradients across the data-parallel GPUs. During the backward pass, instead of using an AllReduce operation to sum gradients across all GPUs, Stage 2 uses a ReduceScatter operation. This operation calculates the sum but immediately scatters the results, so each GPU only receives the partition of the gradients corresponding to its partition of the optimizer states.
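The following sketch illustrates the difference between the two collectives using torch.distributed directly; it is not DeepSpeed's internal implementation. It assumes a multi-GPU launch via torchrun with the NCCL backend and a gradient length divisible by the world size.

import os
import torch
import torch.distributed as dist

# Assumes: launched with torchrun, one GPU per process, NCCL backend available.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

world_size = dist.get_world_size()
rank = dist.get_rank()

# Stand-in for a model's flattened gradients (length divisible by world_size).
grads = torch.arange(8, dtype=torch.float32, device=device) + rank

# Standard DP: AllReduce leaves the full summed gradient on every rank.
full = grads.clone()
dist.all_reduce(full, op=dist.ReduceOp.SUM)

# ZeRO Stage 2: ReduceScatter leaves each rank with only the slice of the sum
# that matches the optimizer-state partition it owns.
chunks = list(grads.clone().chunk(world_size))
shard = torch.empty_like(chunks[0])
dist.reduce_scatter(shard, chunks, op=dist.ReduceOp.SUM)

print(f"rank {rank}: {full.numel()} gradient values after AllReduce, "
      f"{shard.numel()} after ReduceScatter")

dist.destroy_process_group()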
Benefits: Gradient memory is partitioned in addition to optimizer state memory, so both shrink with the number of data-parallel GPUs. Because a ReduceScatter moves roughly the same amount of data as an AllReduce, communication volume remains comparable to standard DP.
Configuration Example (DeepSpeed JSON):
{
  "zero_optimization": {
    "stage": 2
  },
  "fp16": {
    "enabled": true
  },
  "train_batch_size": 32,
  "gradient_accumulation_steps": 1
}
ZeRO Stage 3 is the most aggressive optimization, partitioning all three major model states: optimizer states, gradients, and the model parameters themselves. Each GPU only holds a shard of the parameters at any given time.
This requires more sophisticated management. During the forward and backward passes, each GPU needs access to the full parameters for a given layer computation. ZeRO Stage 3 handles this by dynamically gathering the required parameter shards from other GPUs just before they are needed for computation and discarding them immediately afterward to free memory.
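As an illustrative sketch (not something normal training code needs, since DeepSpeed gathers parameters automatically during forward and backward), the snippet below shows how a partitioned parameter can be inspected from user code. It assumes a DeepSpeed engine (model_engine) created with a Stage 3 configuration, as shown later in this section, and classifier is a hypothetical submodule name.

import deepspeed

# Hypothetical layer; under ZeRO Stage 3 its weight lives as a shard on each rank.
weight = model_engine.module.classifier.weight

# While partitioned, the local tensor is only a placeholder; ZeRO-3 records the
# logical (unpartitioned) shape on the parameter's ds_shape attribute.
print(weight.shape, weight.ds_shape)

with deepspeed.zero.GatheredParameters(weight):
    # Inside the context the full parameter is temporarily materialized,
    # e.g. for inspection or custom initialization.
    print(weight.shape, weight.norm().item())
# On exit the parameter is re-partitioned and the gathered memory is freed.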
Benefits: All three model states are partitioned, so per-GPU model-state memory shrinks roughly linearly with the number of data-parallel GPUs, making it possible to fit models far too large for any single device.
Drawbacks: Parameters must be gathered before each forward and backward computation and released afterward, which increases communication volume (roughly 1.5x that of standard DP) and can reduce throughput, particularly on slower interconnects.
Configuration Example (DeepSpeed JSON):
{
  "zero_optimization": {
    "stage": 3
  },
  "fp16": {
    "enabled": true
  },
  "train_batch_size": 32,
  "gradient_accumulation_steps": 1
}
Approximate reduction in model state memory per GPU across ZeRO stages relative to standard Data Parallelism. Actual savings depend on the optimizer and precision used.
Using ZeRO involves wrapping your model, optimizer, and potentially data loader with DeepSpeed's initialize function. You provide the configuration details (often via a JSON file) to this function.
import torch
import deepspeed
# Assume model, optimizer, dataloader, and args (containing deepspeed_config path) are defined
# Initialize DeepSpeed
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    model_parameters=model.parameters(),
    config_params=args.deepspeed_config  # Path to JSON config file
)
# Training Loop Example Snippet
for step, batch in enumerate(dataloader):
    # Move batch to device managed by DeepSpeed
    batch = {k: v.to(model_engine.local_rank) for k, v in batch.items()}

    # Forward pass
    outputs = model_engine(**batch)
    loss = outputs.loss

    # Backward pass managed by DeepSpeed
    model_engine.backward(loss)

    # Optimizer step managed by DeepSpeed
    model_engine.step()

    # Logging, checkpointing, etc.
DeepSpeed handles the partitioning, communication (gradient reduction, parameter gathering), and optimizer steps according to the specified ZeRO stage in the configuration file.
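Checkpointing should also go through the engine so that each rank's partitions are saved and restored consistently. A minimal sketch, typically called periodically inside the training loop above (so step is available), assuming a checkpoints directory of your choosing:

# Each rank writes its own partition of model and optimizer state.
model_engine.save_checkpoint("checkpoints", tag=f"step_{step}")

# Later, restore the partitions (typically with the same config and world size).
load_path, client_state = model_engine.load_checkpoint("checkpoints", tag=f"step_{step}")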
DeepSpeed also offers ZeRO-Offload variants within Stage 3, which can offload partitions to CPU RAM or NVMe storage for even larger models, albeit at the cost of slower access times.
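If Stage 3 partitioning alone is still not enough, offloading is enabled through the same configuration mechanism. The sketch below shows one plausible Stage 3 configuration with CPU offload, written as a Python dict (deepspeed.initialize also accepts a dict in place of a JSON file path); the optimizer type, learning rate, and batch settings are placeholders mirroring the earlier examples.

import deepspeed

# Hypothetical Stage 3 config with optimizer and parameter offload to CPU RAM.
# For NVMe offload, "device" can be set to "nvme" along with an "nvme_path".
ds_config = {
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "fp16": {"enabled": True},
    "train_batch_size": 32,
    "gradient_accumulation_steps": 1,
}

# model is assumed to be defined as in the earlier initialization example.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)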
Experimentation is often necessary to find the optimal stage for your specific model, hardware configuration, and performance requirements. Start with Stage 1 or 2 and move to Stage 3 if memory constraints demand it, carefully monitoring training throughput to assess the impact of increased communication. ZeRO provides a powerful set of tools to overcome memory barriers in large-scale model training, making it a fundamental technique for building state-of-the-art LLMs.