Theory gives us the map, but implementation builds the engine. For advanced distributed model training, PyTorch offers Fully Sharded Data Parallel (FSDP) as its native solution. FSDP manages the complexity of sharding model state across data-parallel workers and provides a powerful, well-integrated API for advanced memory-saving techniques, including those comparable to DeepSpeed's ZeRO.
This practical guides you through configuring and running a distributed training job for a transformer model on a multi-GPU system. You will not just execute a script; you will learn the mechanics of setting up the distributed environment, correctly wrapping your model with FSDP, managing sharded checkpoints, and launching the job. This hands-on experience is important for moving from distributed training principles to production-ready implementation.
Before we write any code, we must set up our environment. This lab requires a system with at least two GPUs, although the code will function correctly on a single GPU for syntax checking. You will need PyTorch and the transformers library from Hugging Face for a pre-built model and tokenizer.
Install the necessary libraries using pip:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers datasets accelerate
We will launch our training script using torchrun, PyTorch's standard utility for initiating distributed jobs. torchrun automatically manages the necessary environment variables for each process:
WORLD_SIZE: The total number of processes (GPUs) participating in the job.
RANK: The unique global ID for the current process, from 0 to WORLD_SIZE - 1.
LOCAL_RANK: The unique local ID for the current process on a given machine.
Understanding these variables is important for tasks like printing logs or saving checkpoints from only one process (typically rank 0).
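As a minimal standalone sketch (call it check_env.py; the filename is just for illustration and it is not part of train_fsdp.py), the snippet below reads these variables and shows the rank-0 pattern in practice:

import os

# torchrun sets these environment variables for every process it launches.
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])

print(f"rank {rank}/{world_size} (local rank {local_rank}) reporting in")

# Gate one-off work such as logging or checkpoint saving on the global rank.
if rank == 0:
    print("rank 0 will handle logging and checkpointing")

Launching it with torchrun --nproc_per_node=2 check_env.py prints one "reporting in" line per process, but only a single rank-0 message.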
FSDP achieves its memory efficiency by sharding the model's parameters, their corresponding gradients, and the optimizer states across all the GPUs in the data-parallel group. During runtime, each GPU only holds a fraction of the total model state.
When a layer is needed for computation in the forward pass, each GPU gathers the necessary parameter shards from all other GPUs to reconstruct the full layer. After the computation, the full layer is discarded, freeing the memory. A similar, reversed process occurs during the backward pass.
Diagram of the FSDP all_gather operation during a forward pass. Each GPU holds only its shard of the model's state and gathers the remaining shards from peers just-in-time to reconstruct the full layer for computation.
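To make the gather step concrete, here is a purely illustrative sketch of the communication pattern using dist.all_gather directly. It is not how FSDP is implemented internally (FSDP flattens parameters and manages this for you), but it shows the just-in-time reconstruct-then-discard cycle described above.

import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

rank = dist.get_rank()
world_size = dist.get_world_size()

# Each rank permanently holds only its own shard of a (hypothetical) flattened layer.
shard = torch.full((4,), float(rank), device="cuda")

# Just-in-time reconstruction: gather every rank's shard onto every rank.
gathered = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(gathered, shard)
full_layer = torch.cat(gathered)  # the temporarily unsharded parameter

# ... run the layer's computation with full_layer ...

del full_layer, gathered  # discard the full copy to free memory, as FSDP does per layer
dist.destroy_process_group()

Run it with torchrun --nproc_per_node=2 to see each process reconstruct the same full tensor from its local shard.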
FSDP offers several ShardingStrategy options that control this behavior, providing a trade-off between memory savings and communication overhead. The two primary strategies are:
FULL_SHARD: This shards model parameters, gradients, and optimizer states, offering the maximum memory savings. It is analogous to ZeRO-3.
SHARD_GRAD_OP: This shards only gradients and optimizer states, keeping a full copy of the model parameters on each GPU. It saves less memory but reduces communication. It is analogous to ZeRO-2.
For this practical, we will use FULL_SHARD to maximize memory efficiency.
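The strategy is simply an argument to the FSDP wrapper (the full wrapping call appears later in this practical), so switching strategies is a one-line change, sketched here:

from torch.distributed.fsdp import ShardingStrategy

# Maximum memory savings; analogous to ZeRO-3 (used in this practical).
strategy = ShardingStrategy.FULL_SHARD

# Alternative: shard only gradients and optimizer states; analogous to ZeRO-2.
# strategy = ShardingStrategy.SHARD_GRAD_OP

# Later: model = FSDP(model, sharding_strategy=strategy, ...)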
The first step in any PyTorch distributed script is to initialize the process group. This function establishes the communication backend (like nccl for NVIDIA GPUs) and allows the processes to discover each other.
Create a file named train_fsdp.py and add the setup function.
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
def setup():
    """Initialize the distributed environment."""
    dist.init_process_group("nccl")
    # Set the device for the current process.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

def cleanup():
    """Clean up the distributed environment."""
    dist.destroy_process_group()
We'll use a GPT2 model and its tokenizer from the transformers library. For a real training run, you would use a large dataset; here, we'll create a simple dummy dataset for demonstration.
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_model_and_tokenizer():
    """Load a pre-trained model and tokenizer."""
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Add a padding token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer

def get_dummy_dataloader(tokenizer, batch_size=4):
    """Create a dummy dataloader for demonstration."""
    dummy_data = ["This is a test sentence for FSDP." for _ in range(100)]
    encoded_data = tokenizer(dummy_data, return_tensors="pt", padding=True, truncation=True, max_length=128)

    dataset = torch.utils.data.TensorDataset(encoded_data.input_ids, encoded_data.attention_mask)

    # Sampler is important for distributing data across GPUs
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return dataloader
Notice the use of DistributedSampler. This is a critical component that ensures each GPU receives a unique, non-overlapping slice of the data for each epoch.
This is where FSDP is configured. Instead of wrapping the entire model in one large FSDP unit, it's more efficient to wrap individual layers or blocks. This allows FSDP to free the memory for a layer's parameters immediately after it's used in the forward and backward passes.
The auto_wrap_policy makes this easy. We'll use a size_based_auto_wrap_policy which automatically wraps any submodule that exceeds a certain number of parameters.
# (This code goes inside your main training function)

# Define the auto wrap policy.
# Wrap any submodule with more than 1M parameters; for GPT-2 this captures
# the transformer blocks. Adjust this threshold for your model architecture.
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,
)

# Get the local rank
local_rank = int(os.environ["LOCAL_RANK"])

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
    # cpu_offload=CPUOffload(offload_params=True)  # Optional: Offload to CPU
)
The device_id argument is important; it tells FSDP which GPU to move the model shards to. The commented-out cpu_offload parameter shows how you could offload parameters to CPU RAM if you are extremely memory-constrained, at the cost of slower performance due to PCIe data transfers.
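If you do enable offloading, the only changes are one extra import and the cpu_offload argument. A sketch, assuming the same auto_wrap_policy as above:

from torch.distributed.fsdp import CPUOffload, ShardingStrategy

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
    # Keep parameter shards in CPU memory and move them to the GPU only when
    # needed for computation. Saves GPU memory at the cost of PCIe transfers.
    cpu_offload=CPUOffload(offload_params=True),
)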
The training loop itself is nearly identical to a standard, non-distributed PyTorch loop. The optimizer must be defined after the model has been wrapped in FSDP, as FSDP replaces the model's parameters with its own FlatParameter objects.
# (Inside your main training function, after wrapping the model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
for epoch in range(1, 3):  # Train for 2 epochs
    dataloader.sampler.set_epoch(epoch)  # Ensure shuffling is different each epoch
    for batch_idx, (input_ids, attention_mask) in enumerate(dataloader):
        input_ids = input_ids.to(local_rank)
        attention_mask = attention_mask.to(local_rank)

        optimizer.zero_grad()

        # The forward pass automatically handles the all-gather
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss

        # The backward pass handles the reduce-scatter
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            print(f"Epoch: {epoch}/2 | Batch: {batch_idx} | Loss: {loss.item():.4f}")
We use dist.get_rank() == 0 to ensure that the print statement is only executed by a single process, preventing a storm of identical log messages.
Saving and loading with FSDP requires a specific approach because the model state is sharded across all GPUs. You must decide whether to save a sharded checkpoint (faster, but requires the same WORLD_SIZE to load) or a full, consolidated checkpoint (more portable).
To save a full checkpoint, we need to gather the entire model state onto a single rank (usually rank 0) before saving.
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

# --- Saving a full state dict ---
if dist.get_rank() == 0:
    print("Saving consolidated model checkpoint...")

# Use a context manager to get the full state dict.
# state_dict() is a collective here, so every rank must enter this block.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    full_state_dict = model.state_dict()

# Only rank 0 saves the file
if dist.get_rank() == 0:
    torch.save(full_state_dict, "full_model_checkpoint.pt")

# --- Loading a full state dict ---
# To load, you first wrap the model in FSDP and then load the state.
# model = FSDP(...)  # Initialize FSDP model as before
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    # Load the checkpoint on CPU first to avoid OOM on a single GPU
    checkpoint = torch.load("full_model_checkpoint.pt", map_location="cpu")
    model.load_state_dict(checkpoint)
This pattern ensures that the state is correctly gathered from all shards before saving and correctly scattered back when loading.
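For completeness, here is one way to save a sharded checkpoint instead, using StateDictType.SHARDED_STATE_DICT together with the torch.distributed.checkpoint module. Treat it as a sketch: the helper names have shifted between recent PyTorch releases, and the "sharded_checkpoint/" directory is just an example path. Every rank writes its own shards, which avoids gathering the full model but ties the checkpoint to the sharded layout.

import torch.distributed.checkpoint as dist_cp

# --- Saving a sharded state dict (every rank participates and writes) ---
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_state = {"model": model.state_dict()}
    dist_cp.save_state_dict(
        state_dict=sharded_state,
        storage_writer=dist_cp.FileSystemWriter("sharded_checkpoint/"),
    )

# --- Loading a sharded state dict ---
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Get a sharded "template" state dict, fill it from disk, then load it.
    sharded_state = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=sharded_state,
        storage_reader=dist_cp.FileSystemReader("sharded_checkpoint/"),
    )
    model.load_state_dict(sharded_state["model"])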
Now, assemble all the pieces into the train_fsdp.py script. The main execution block should look like this:
def main():
    setup()
    rank = int(os.environ["RANK"])

    model, tokenizer = get_model_and_tokenizer()
    dataloader = get_dummy_dataloader(tokenizer)

    # ... [FSDP wrapping logic here] ...
    # ... [Optimizer definition here] ...
    # ... [Training loop here] ...
    # ... [Checkpoint saving logic here] ...

    cleanup()

if __name__ == '__main__':
    main()
Launch the training job from your terminal. This command tells torchrun to launch 2 processes on the current machine, each running the train_fsdp.py script.
torchrun --nproc_per_node=2 train_fsdp.py
If you have 4 GPUs, you would use --nproc_per_node=4. The script will execute, and you will see the loss being printed by rank 0. You can monitor your GPU memory usage with watch -n 1 nvidia-smi. You should see that the memory used on each GPU is significantly less than what would be required to hold the entire GPT-2 model.
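The same script also scales to multiple machines; torchrun just needs rendezvous arguments so the nodes can find each other. A sketch assuming two nodes with 4 GPUs each, where NODE_RANK is 0 on the first machine and 1 on the second, and MASTER_ADDR is the first node's address (both placeholders you set yourself):

torchrun --nnodes=2 --node_rank=$NODE_RANK --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 train_fsdp.py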
By completing this practical, you have not only executed a distributed training job but have also engaged with the core mechanics of FSDP: initialization, auto-wrapping policies, the DistributedSampler, and state dictionary management. These skills are directly applicable to training large-scale models in a production environment.