FullyShardedDataParallel (FSDP) offers a different approach to distributed training than DistributedDataParallel (DDP), and migrating requires more than a simple class substitution. While their API surfaces appear similar, the underlying mechanics of parameter management differ fundamentally. In DDP, the model wrapper primarily handles gradient synchronization (all-reduce) at the bucket level. In FSDP, the wrapper takes ownership of the parameters, physically partitioning the tensors across the process group.

This section details the programmatic implementation of FSDP, focusing on the `ShardingStrategy` configuration and the `MixedPrecision` policies required to stabilize training at scale.

## The FSDP Constructor Interface

The entry point for sharding is the `FullyShardedDataParallel` class. Unlike DDP, which accepts a model that already resides on the target GPU, FSDP often wraps models residing on the CPU to avoid immediate Out-Of-Memory (OOM) errors during initialization.

The constructor signature exposes controls for the ZeRO stages discussed in the theoretical analysis. The most significant argument is `sharding_strategy`, which dictates how the model state is partitioned.

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    BackwardPrefetch,
)

# Basic wrapping structure
model = FSDP(
    module,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mixed_precision_policy,
    device_id=torch.cuda.current_device()
)
```

## Mapping ZeRO Stages to Sharding Strategies

PyTorch maps the mathematical definitions of ZeRO optimization to the `ShardingStrategy` enum. Selecting the correct strategy depends on the balance between memory constraints and communication bandwidth; a short sketch after this list shows how the choice is expressed in code.

- **`FULL_SHARD` (ZeRO Stage 3):** This is the default behavior. Parameters, gradients, and optimizer states are all sharded. Parameters are only materialized (gathered) on the GPU during the forward and backward passes of the specific layer being computed, then immediately freed. This offers maximum memory savings ($$1/N$$ scaling) but incurs high communication overhead due to frequent AllGather operations.
- **`SHARD_GRAD_OP` (ZeRO Stage 2):** Gradients and optimizer states are sharded, but parameters remain replicated. This strategy avoids the communication overhead of gathering weights during the forward pass but requires enough VRAM to hold the full model parameters ($$\Psi$$).
- **`NO_SHARD` (DDP equivalent):** This mode replicates the behavior of DDP within the FSDP API. It is useful for debugging, or when scaling down to a small number of GPUs where the sharding overhead outweighs the benefits.
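The following sketch shows how this choice is expressed at the call site. It reuses the `module` and `mixed_precision_policy` placeholders from the constructor snippet above; only the enum value changes between strategies.

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

# ZeRO-2 style wrapping: gradients and optimizer states are sharded,
# but each rank keeps a full copy of the parameters, so no AllGather
# is needed during the forward pass. `module` and
# `mixed_precision_policy` are the placeholders used earlier.
zero2_model = FSDP(
    module,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    mixed_precision=mixed_precision_policy,
    device_id=torch.cuda.current_device(),
)

# Substituting ShardingStrategy.NO_SHARD in the same call produces the
# DDP-equivalent baseline described above.
```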
The following diagram illustrates the memory allocation differences between these strategies across a two-GPU setup.

```dot
digraph G {
    rankdir=TB;
    node [shape=rect, style=filled, fontname="Arial", fontsize=10];
    splines=ortho;
    bgcolor="transparent";

    subgraph cluster_0 {
        label="Standard DDP / NO_SHARD";
        style=dashed;
        color="#adb5bd";

        subgraph cluster_gpu0_ddp {
            label="GPU 0"; color="#dee2e6"; style=filled;
            node [width=1.5];
            P0 [label="Params (Full)", fillcolor="#a5d8ff"];
            G0 [label="Grads (Full)", fillcolor="#ffc9c9"];
            O0 [label="Optim (Full)", fillcolor="#b2f2bb"];
        }

        subgraph cluster_gpu1_ddp {
            label="GPU 1"; color="#dee2e6"; style=filled;
            node [width=1.5];
            P1 [label="Params (Full)", fillcolor="#a5d8ff"];
            G1 [label="Grads (Full)", fillcolor="#ffc9c9"];
            O1 [label="Optim (Full)", fillcolor="#b2f2bb"];
        }
    }

    subgraph cluster_1 {
        label="FSDP FULL_SHARD (ZeRO-3)";
        style=dashed;
        color="#adb5bd";

        subgraph cluster_gpu0_fsdp {
            label="GPU 0"; color="#dee2e6"; style=filled;
            node [width=1.5];
            P0_s [label="Params [Shard 0]", fillcolor="#4dabf7"];
            G0_s [label="Grads [Shard 0]", fillcolor="#ff8787"];
            O0_s [label="Optim [Shard 0]", fillcolor="#69db7c"];
        }

        subgraph cluster_gpu1_fsdp {
            label="GPU 1"; color="#dee2e6"; style=filled;
            node [width=1.5];
            P1_s [label="Params [Shard 1]", fillcolor="#4dabf7"];
            G1_s [label="Grads [Shard 1]", fillcolor="#ff8787"];
            O1_s [label="Optim [Shard 1]", fillcolor="#69db7c"];
        }
    }
}
```

*Memory allocation comparison between standard replication and full sharding. In FULL_SHARD, both the storage and the computation of updates are distributed, reducing the per-device footprint roughly linearly with the number of devices.*

## Mixed Precision Configuration

Unlike `torch.cuda.amp`, which typically requires an external gradient scaler and an autocast context manager, FSDP integrates mixed precision directly into the sharding logic. This integration is necessary because FSDP must know the precision format before communicating tensors across ranks: if the communication occurs in FP32 while the computation is in BF16, bandwidth is wasted.

The `MixedPrecision` config class controls three specific data types:

- `param_dtype`: The type used for model parameters during forward and backward computation.
- `reduce_dtype`: The type used for gradient reduction (communication).
- `buffer_dtype`: The type used for buffers (e.g., batch norm statistics).

For modern LLM training on Ampere or Hopper architectures, bfloat16 is the standard. It preserves the exponent range of FP32, avoiding the underflow issues common with FP16 and often eliminating the need for loss scaling.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Define the policy for BFloat16 training
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,  # Reduces communication volume by 50%
    buffer_dtype=torch.bfloat16,
)
```

When `param_dtype` is set to bfloat16, FSDP keeps the master weights in FP32 (if they exist) but casts them to BFloat16 before the forward pass. This aligns with the "Master Weights" concept in mixed-precision training, ensuring that small weight updates in the optimizer step are not lost to precision truncation.

## Flat Parameters and the Forward Pass

When you wrap a module with FSDP, the original parameters (e.g., `model.layer1.weight`) are replaced by a `FlatParameter`: a single 1D tensor that views the storage of multiple original parameters. This flattening improves memory access patterns and communication efficiency. Instead of launching hundreds of small NCCL kernels for individual weight matrices, FSDP aggregates them into larger chunks.

However, this introduces a constraint: you cannot access `model.layer1.weight` directly after wrapping, because the attribute no longer exists in its original form on the device. Accessing it requires the `FSDP.summon_full_params(model)` context manager, which we will address in the checkpointing section.
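As a brief preview, the following is a minimal sketch assuming `fsdp_model` is a module already wrapped with FSDP as shown above; the context manager temporarily gathers the full parameters so their original per-layer shapes are visible again.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch: `fsdp_model` is assumed to be an FSDP-wrapped module.
# Outside this context, iterating over parameters yields the flattened,
# locally sharded storage rather than the original per-layer tensors.
with FSDP.summon_full_params(fsdp_model):
    # Inside the context, each rank temporarily materializes the full,
    # unsharded parameters, so the original per-layer shapes reappear.
    for param in fsdp_model.parameters():
        print(param.shape)
# On exit, the gathered tensors are freed and only local shards remain.
```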
## Complete Implementation Example

The following example demonstrates an end-to-end initialization pattern. It sets up the process group, defines a model, and wraps it with `FULL_SHARD` and BFloat16 precision.

```python
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    StateDictType,
)


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize the process group with the NCCL backend for GPUs
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def train_fsdp_step(rank, world_size):
    setup(rank, world_size)

    # 1. Define Model (Standard PyTorch)
    # In a real scenario, this is likely a Transformer
    model = nn.Sequential(
        nn.Linear(1024, 4096),
        nn.ReLU(),
        nn.Linear(4096, 1024)
    )

    # 2. Define Policies
    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # 3. Wrap with FSDP
    # Note: The model is moved to the local GPU via the device_id argument.
    # For massive models, load on the 'meta' device or CPU first.
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=bf16_policy,
        device_id=torch.cuda.current_device(),
    )

    # 4. Optimizer Initialization
    # Important: the optimizer must be created AFTER wrapping,
    # because FSDP replaces the parameter structure.
    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3)

    # 5. Training Loop
    # Generate dummy data on the correct device
    inputs = torch.randn(64, 1024).to(rank).bfloat16()

    optimizer.zero_grad()
    output = fsdp_model(inputs)
    loss = output.sum()
    loss.backward()
    optimizer.step()

    print(f"Rank {rank} step complete. Loss: {loss.item()}")
    cleanup()


# This function would be spawned via torch.multiprocessing in a script:
# mp.spawn(train_fsdp_step, args=(world_size,), nprocs=world_size)
```

## Device Initialization and the device_id

In the example above, the `device_id` argument is critical. If it is omitted, FSDP may attempt to initialize shards on the CPU or on the default GPU, leading to silent performance degradation or placement errors.

When working with models that fit on a single GPU (like the example), wrapping an already materialized model is acceptable. For models approaching the terabyte scale, however, materializing the full model in CPU RAM before sharding is impossible. In those cases, we use delayed initialization with the meta device, allowing FSDP to shard the parameters without ever allocating the full model state on any single device. This advanced initialization pattern is the subject of the next chapter.
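For orientation only, here is a minimal and deliberately incomplete sketch of that pattern, using a hypothetical oversized model and FSDP's `param_init_fn` hook to allocate only the local shard on each GPU. The full recipe, including how the weights are re-initialized or loaded afterwards, is covered in the next chapter.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch: constructing the module under the 'meta' device records shapes
# and dtypes without allocating any real storage on CPU or GPU.
with torch.device("meta"):
    huge_model = nn.Sequential(
        nn.Linear(16384, 65536),
        nn.ReLU(),
        nn.Linear(65536, 16384),
    )

# One way to materialize only the local shard: param_init_fn is invoked
# per submodule, and to_empty() allocates (uninitialized) storage on the
# local GPU. Real training would then reset or load the weights.
fsdp_model = FSDP(
    huge_model,
    device_id=torch.cuda.current_device(),
    param_init_fn=lambda m: m.to_empty(
        device=torch.cuda.current_device(), recurse=False
    ),
)
```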