While DistributedDataParallel (DDP) effectively scales training across multiple GPUs by replicating the model and averaging gradients, it fundamentally requires each GPU to hold the entire model, its gradients, and the optimizer states. This becomes a limiting factor when dealing with models containing billions of parameters, which can easily exceed the memory capacity of even high-end accelerators.
Fully Sharded Data Parallelism (FSDP) offers a solution by extending the principles of data parallelism while dramatically reducing the memory footprint on each GPU. Instead of replicating the entire model, FSDP shards, or partitions, the model's parameters, gradients, and optimizer states across the data parallel workers (GPUs).
At its core, FSDP ensures that each GPU in the data parallel group only holds a fraction (a "shard") of the model's parameters, gradients, and optimizer states at any given time. Full tensors are reconstructed temporarily only when needed for computation.
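To get a feel for the savings, the following back-of-the-envelope estimate compares the persistent per-GPU memory for parameters, gradients, and Adam optimizer states under replication (as in DDP) and full sharding (as in FSDP). The model size and GPU count are hypothetical illustrative values, and activations and temporary buffers are ignored.

# Rough per-GPU memory estimate (parameters, gradients, Adam states only).
# Illustrative numbers; activations, buffers, and fragmentation are ignored.
num_params = 10_000_000_000      # hypothetical 10B-parameter model
world_size = 8                   # hypothetical number of GPUs

bytes_per_param = 4              # fp32 parameters
bytes_per_grad = 4               # fp32 gradients
bytes_per_adam_state = 8         # fp32 exp_avg + exp_avg_sq

per_param_bytes = bytes_per_param + bytes_per_grad + bytes_per_adam_state  # 16 bytes

ddp_gb = num_params * per_param_bytes / 1e9                 # everything replicated
fsdp_gb = num_params * per_param_bytes / world_size / 1e9   # everything sharded

print(f"DDP per GPU:  ~{ddp_gb:.0f} GB")   # ~160 GB: exceeds any single GPU
print(f"FSDP per GPU: ~{fsdp_gb:.0f} GB")  # ~20 GB: fits on a modern accelerator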
Here's a conceptual breakdown of the process during training:
- Initialization: The model is wrapped with the FullyShardedDataParallel module. During initialization, the parameters, gradients, and optimizer states are partitioned across the GPUs participating in the process group. Each GPU is responsible for managing its assigned shard.
- Forward Pass: Before a wrapped submodule runs, each GPU temporarily reconstructs that submodule's full parameters by collecting the missing shards from its peers via an all_gather collective communication operation. Once the submodule's computation finishes, the gathered parameters are freed and only the local shard is retained.
- Backward Pass: The full parameters for each submodule are gathered again (via all_gather) so that local gradients can be computed. Once the local gradients are available, a reduce_scatter operation is performed. This operation computes the average of the gradients across all GPUs and simultaneously shards the averaged result, sending each GPU only the portion (shard) of the gradient corresponding to the parameter shard it manages. The gathered parameters and full gradients are then discarded.
- Optimizer Step: Each GPU updates only its own parameter shard, using its gradient shard and its shard of the optimizer states. No additional communication is required for the update itself.
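To make this gather/compute/free cycle concrete, the sketch below imitates it for a single wrapped unit using raw torch.distributed collectives. It is illustrative only: it assumes an initialized process group, a contiguous flattened parameter shard, and a user-supplied compute_loss function (a hypothetical stand-in for the wrapped module's forward pass). Real FSDP handles flattening, prefetching, and memory reuse internally.

import torch
import torch.distributed as dist

def fsdp_unit_step(local_shard, inputs, compute_loss):
    """Conceptual sketch of FSDP's per-unit flow using raw collectives.

    local_shard: this rank's slice of one wrapped unit's flattened parameters.
    compute_loss(full_params, inputs): runs the unit and returns a scalar loss.
    Assumes dist.init_process_group() has already been called.
    """
    world_size = dist.get_world_size()

    # Gather all shards to rebuild the unit's full parameter tensor (all_gather).
    full_params = torch.empty(world_size * local_shard.numel(),
                              dtype=local_shard.dtype, device=local_shard.device)
    dist.all_gather_into_tensor(full_params, local_shard)
    full_params.requires_grad_(True)

    # Forward and backward with the temporarily materialized full parameters.
    loss = compute_loss(full_params, inputs)
    loss.backward()  # fills full_params.grad on every rank

    # Sum gradients across ranks and keep only this rank's slice (reduce_scatter),
    # then divide to obtain the average.
    grad_shard = torch.empty_like(local_shard)
    dist.reduce_scatter_tensor(grad_shard, full_params.grad, op=dist.ReduceOp.SUM)
    grad_shard /= world_size

    del full_params    # the full parameters are freed; only the shard persists
    return grad_shard  # the optimizer updates just this rank's parameter shard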
This approach drastically reduces the peak memory required per GPU: each GPU persistently stores only its shards of the parameters, gradients, and optimizer states, plus a temporary full copy of the parameters for the layer (or wrapped unit) currently being executed.
Figure: Comparison of memory allocation per GPU for Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP). DDP replicates all components, while FSDP shards them.
PyTorch provides native support for FSDP through the torch.distributed.fsdp.FullyShardedDataParallel class. Integrating it often involves wrapping your model definition.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools

# Assume distributed environment is initialized (rank, world_size, etc.)
# dist.init_process_group(backend="nccl")
# torch.cuda.set_device(local_rank)  # local_rank typically obtained from the launcher

class LargeTransformerBlock(nn.Module):
    # Example submodule definition
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, num_heads=8)  # Simplified
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, dim)
        )

    def forward(self, x):
        x = self.layer_norm(x + self.attention(x, x, x)[0])
        x = x + self.ffn(x)
        return x

class BigModel(nn.Module):
    def __init__(self, num_layers, dim, ff_dim, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            [LargeTransformerBlock(dim, ff_dim) for _ in range(num_layers)]
        )
        self.output_head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_head(x)
        return x

# --- FSDP Setup ---
model = BigModel(num_layers=48, dim=2048, ff_dim=8192, vocab_size=50000).to(torch.cuda.current_device())

# Define an auto-wrap policy (optional but recommended for large models)
# This wraps submodules (like LargeTransformerBlock) based on size
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1_000_000  # Example threshold
)

# Wrap the model with FSDP
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    # Other configuration options can be added here
    # e.g., cpu_offload=CPUOffload(offload_params=True)
    # e.g., mixed_precision=MixedPrecision(...)
    # e.g., sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
)

# --- Training Loop ---
# Optimizer must be constructed AFTER wrapping the model with FSDP
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

# Example training step (simplified)
# for batch in dataloader:
#     inputs = batch['input_ids'].to(torch.cuda.current_device())
#     labels = batch['labels'].to(torch.cuda.current_device())
#
#     optimizer.zero_grad()
#     outputs = fsdp_model(inputs)
#     loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
#     loss.backward()
#     optimizer.step()
Key points in the implementation:
- The model itself is defined as a standard nn.Module; no FSDP-specific changes to the architecture are required.
- The model is wrapped with FSDP. Note that the model should be moved to the target device before wrapping.
- The optimizer is constructed after wrapping the model with FSDP, passing fsdp_model.parameters() to it. This ensures the optimizer is aware of the sharded parameters and states.
- The choice of auto_wrap_policy is important. This policy tells FSDP how to recursively wrap submodules within your main model. Wrapping individual blocks (like Transformer layers) allows for finer-grained sharding and better overlap of communication and computation. size_based_auto_wrap_policy is a common choice, wrapping modules that exceed a certain parameter count.

FSDP offers several configuration options to tailor its behavior:
- sharding_strategy: Controls how aggressively parameters, gradients, and optimizer states are sharded.
  - ShardingStrategy.FULL_SHARD: (Default) Shards parameters, gradients, and optimizer states. Offers maximum memory savings but potentially higher communication.
  - ShardingStrategy.SHARD_GRAD_OP: Shards gradients and optimizer states only; parameters are replicated (similar to ZeRO Stage 2). Less memory saving than FULL_SHARD but potentially lower communication overhead.
  - ShardingStrategy.NO_SHARD: Equivalent to DDP (replicates everything). Useful for debugging or baseline comparison.
  - ShardingStrategy.HYBRID_SHARD: Combines full sharding within a node with replication across nodes. Useful in multi-node scenarios.
- cpu_offload: Configured via CPUOffload(offload_params=True/False). Allows offloading shards of parameters and gradients to CPU RAM when they are not actively used in computation. This further increases the feasible model size at the cost of significant communication overhead between CPU and GPU. Use this when GPU memory is the absolute bottleneck.
- mixed_precision: Configured via MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16). Integrates mixed-precision training directly within the FSDP wrapper, handling the casting of parameters, gradients, and buffers automatically. It is generally recommended to use FSDP's built-in mixed precision rather than applying torch.cuda.amp.autocast externally.
- auto_wrap_policy: As discussed, defines how nested modules are wrapped. Alternatives to size_based_auto_wrap_policy include wrapping based on module type (transformer_auto_wrap_policy) or manual wrapping.
- backward_prefetch: Controls prefetching of parameters for the backward pass to overlap communication and computation. Options like BackwardPrefetch.BACKWARD_PRE (prefetch the next unit's parameters during the current unit's backward) can improve performance.
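The snippet below sketches how several of these options might be combined, assuming the BigModel and LargeTransformerBlock classes from the earlier example and a fresh, not-yet-wrapped model instance already placed on the current GPU. The specific dtypes, strategy, and policy choices are illustrative, not recommendations.

import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap each LargeTransformerBlock (defined earlier) as its own FSDP unit.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LargeTransformerBlock},
)

fsdp_model = FSDP(
    model,                                            # unwrapped BigModel, already on GPU
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # shard params, grads, and optimizer states
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,                   # compute in bf16
        reduce_dtype=torch.bfloat16,                  # communicate gradients in bf16
        buffer_dtype=torch.bfloat16,
    ),
    cpu_offload=CPUOffload(offload_params=False),     # keep shards on GPU in this sketch
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next unit's params in backward
)

Wrapping at the LargeTransformerBlock level keeps each FSDP unit roughly one transformer layer in size, which tends to balance memory savings against communication granularity.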
While FSDP enables training significantly larger models, it introduces trade-offs:
- Communication overhead: The repeated all_gather (forward) and reduce_scatter (backward) operations introduce more communication volume compared to DDP's single all_reduce in the backward pass. The performance impact depends heavily on the interconnect speed between GPUs/nodes. Faster interconnects (e.g., NVLink, InfiniBand) mitigate this overhead more effectively.
- Overlap of communication and computation: FSDP hides much of this cost by prefetching parameters while other work is running. Settings such as backward_prefetch and a well-chosen auto_wrap_policy are important to maximize this overlap.
- Activation memory: Sharding does not reduce the memory used by activations. Activation checkpointing, built on torch.utils.checkpoint.checkpoint, can be combined with FSDP, and PyTorch provides helpers (such as apply_activation_checkpointing) to apply this efficiently to wrapped modules.
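A hedged sketch of adding activation checkpointing on top of FSDP is shown below, using PyTorch's checkpoint_wrapper utilities applied to the fsdp_model and LargeTransformerBlock from the earlier examples; the helper lives in a private module whose exact path has shifted between releases, so treat this as an outline rather than version-exact code.

import functools
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Recompute each transformer block's activations during backward instead of
# storing them, trading extra compute for a lower activation memory footprint.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    fsdp_model,                                            # FSDP-wrapped model from above
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda module: isinstance(module, LargeTransformerBlock),
)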
In summary, FSDP is a powerful technique for training extremely large models that do not fit into single GPU memory. By sharding parameters, gradients, and optimizer states across data parallel workers, it significantly lowers the per-GPU memory requirement. However, this comes at the cost of potentially increased communication overhead, making fast interconnects and careful configuration important for achieving good training performance. It represents a significant advancement in large-scale model training capabilities within PyTorch.