Alright, let's translate the theory of DistributedDataParallel (DDP) into practice. This section guides you through converting a standard single-GPU PyTorch training script into one that leverages DDP for multi-GPU training on a single machine. We assume you have a functional single-GPU script ready. The goal is not to write a complete state-of-the-art model trainer here, but to illustrate the specific modifications required to enable DDP.
The core idea behind DDP is simple: replicate the model on each available GPU, feed each replica a different slice of the input data batch, compute gradients independently on each GPU, and then average these gradients across all GPUs before updating the model parameters. This ensures all model replicas remain synchronized.
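Before touching any DDP APIs, it can help to see why averaging per-replica gradients is equivalent to computing the gradient on the full batch. The sketch below is plain PyTorch on the CPU with no DDP involved; the model, data, and shard split are made up purely for illustration.

import torch

# Conceptual sketch (plain PyTorch, no DDP): two "replicas" share the same
# weights but see different halves of a batch. Averaging their gradients of the
# per-shard mean loss reproduces the gradient of the full-batch mean loss.
torch.manual_seed(0)
w = torch.randn(10, 1)
x = torch.randn(8, 10)
y = torch.randn(8, 1)

def grad_of_mean_loss(xb, yb):
    wb = w.clone().requires_grad_(True)
    loss = ((xb @ wb - yb) ** 2).mean()
    loss.backward()
    return wb.grad

g0 = grad_of_mean_loss(x[:4], y[:4])  # shard for "GPU 0"
g1 = grad_of_mean_loss(x[4:], y[4:])  # shard for "GPU 1"
g_full = grad_of_mean_loss(x, y)      # full batch on one device

print(torch.allclose((g0 + g1) / 2, g_full, atol=1e-6))  # True

Because the shards are equal-sized and the loss is a mean, the averaged per-shard gradients match the full-batch gradient exactly; this is the equivalence DDP relies on.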
We'll break down the conversion process into manageable steps:

- Initializing the distributed process group so the processes can coordinate.
- Assigning each process to its own GPU.
- Using DistributedSampler to give each process a unique portion of the dataset.
- Wrapping the model with DistributedDataParallel.
- Launching the script with torchrun to start the distributed processes.

Let's detail each step.
Every DDP script needs to initialize the distributed environment. This allows processes to discover each other and coordinate. We use torch.distributed.init_process_group.
import torch
import torch.distributed as dist
import os

def setup(rank, world_size):
    """Initializes the distributed environment."""
    os.environ['MASTER_ADDR'] = 'localhost'  # Address of the master node
    os.environ['MASTER_PORT'] = '12355'      # An available port

    # Initialize the process group.
    # Requires rank and world_size. Backend 'nccl' is recommended for NVIDIA GPUs.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    print(f"Initialized process group for rank {rank} of {world_size} processes.")

def cleanup():
    """Destroys the process group."""
    dist.destroy_process_group()
    print("Destroyed process group.")

# --- In your main execution flow ---
# world_size = torch.cuda.device_count()  # Assuming all available GPUs are used
# rank = ...  # Provided by the launcher (torchrun)
# setup(rank, world_size)
# ... training code ...
# cleanup()
- rank: A unique identifier for the current process (from 0 to world_size - 1).
- world_size: The total number of processes participating in the distributed job.
- backend: The communication library to use. nccl is highly optimized for NVIDIA GPUs; gloo is an alternative for CPU or environments without nccl.
- MASTER_ADDR and MASTER_PORT: These tell the processes where to find the primary process (rank 0) for initial coordination. localhost is sufficient for single-node training.

Note: When using torchrun, rank and world_size (along with other variables like LOCAL_RANK) are typically managed automatically and passed to your script. You often retrieve the rank via an argument parser or directly from environment variables if needed elsewhere, as in the short sketch below.
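For example, a script launched with torchrun can read these values straight from the environment. The variable names are the ones torchrun exports; the parsing itself is a minimal sketch.

import os

# Identifiers exported by torchrun for this process (minimal sketch).
rank = int(os.environ["RANK"])              # global rank of this process
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node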
Each process needs to operate on its assigned GPU. The typical convention is to use the local_rank, which is the GPU index within the current node. For single-node training, local_rank is often the same as the global rank, but relying on local_rank is good practice for portability. torchrun sets the LOCAL_RANK environment variable.
# At the beginning of your training script or function:
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
# --- Example usage ---
# model = YourModel().to(device) # Move model to the assigned GPU
# data = data.to(device) # Move data to the assigned GPU
# labels = labels.to(device)
By setting torch.cuda.set_device(local_rank), subsequent CUDA operations and tensor allocations by that process will default to the correct GPU. Explicitly moving the model and data using .to(device) is still required.
To ensure each GPU processes a unique subset of the data, replace the standard DataLoader shuffling with torch.utils.data.distributed.DistributedSampler.
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

# Assume 'train_dataset' is your torch.utils.data.Dataset instance.
# rank and world_size are obtained after init_process_group.
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)

# Important: Do not pass shuffle=True to the DataLoader; DistributedSampler handles shuffling.
train_loader = DataLoader(
    train_dataset,
    batch_size=per_device_batch_size,  # Batch size PER GPU
    sampler=train_sampler,
    num_workers=4,    # Adjust as needed
    pin_memory=True   # Recommended for performance
)

# --- Inside the training loop ---
for epoch in range(num_epochs):
    # Set the epoch for the sampler to ensure shuffling varies across epochs
    train_sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        # ... rest of the training step ...
- DistributedSampler automatically splits the dataset indices among the processes (num_replicas=world_size).
- shuffle=True in the sampler ensures data is shuffled before partitioning.
- The DataLoader's batch_size now refers to the batch size per process/GPU. The effective total batch size across all GPUs is per_device_batch_size * world_size.
- Calling sampler.set_epoch(epoch) at the start of each epoch is important for proper shuffling behavior over multiple epochs.
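To see how the partitioning works, you can instantiate DistributedSampler outside of any process group by passing num_replicas and rank explicitly; the toy dataset below is purely illustrative.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Sketch: inspect how DistributedSampler splits indices across two "ranks".
# No process group is needed when num_replicas and rank are given explicitly.
dataset = TensorDataset(torch.arange(10))
for r in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=False)
    print(r, list(sampler))
# Expected output:
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]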
After creating your model and moving it to the correct device, wrap it using torch.nn.parallel.DistributedDataParallel.
from torch.nn.parallel import DistributedDataParallel as DDP
# Assume 'model' is your nn.Module instance already moved to 'device'
# model = YourModel().to(device)
# Wrap the model
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
# Now use 'model' for forward passes as usual.
# DDP handles gradient synchronization automatically during backward().
- device_ids: Specifies the GPU this process's model replica resides on. Typically [local_rank].
- output_device: Specifies where the output of the forward pass should be gathered. Usually also local_rank. DDP handles this internally.

DDP works by adding hooks to the model's backward() pass. When loss.backward() is called, gradients are computed locally on each GPU, and then DDP triggers an all-reduce operation to sum/average gradients across all processes before the model parameters are updated. This ensures all model replicas stay synchronized.
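You can observe this synchronization directly. The sketch below is a standalone toy example, separate from the training script: it uses the gloo backend on CPU so it runs without GPUs, and the port number and tiny model are arbitrary choices. After backward(), every rank prints the same gradient even though each rank saw different data.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # Toy setup on CPU with the gloo backend; the port is chosen arbitrarily.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(4, 1))  # CPU module: no device_ids needed
    data = torch.randn(8, 4) + rank     # each rank gets different data
    model(data).sum().backward()        # DDP all-reduces (averages) gradients here

    # Every rank prints identical values thanks to the all-reduce.
    print(f"rank {rank} weight grad: {model.module.weight.grad.flatten().tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)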
The core training logic (forward pass, loss calculation, optimizer.step()) remains largely unchanged. However, consider these points:

- Call train_sampler.set_epoch(epoch) at the start of every epoch so shuffling differs across epochs.
- Metrics such as the loss are computed per process on that process's shard of the data. To log a value that reflects all GPUs, aggregate it across processes with dist.all_reduce.

# --- Inside the training loop after calculating loss ---
loss = criterion(outputs, target)

# Create a separate tensor for aggregation so the 'loss' tensor used for backward() is not modified
loss_tensor = torch.tensor([loss.item()], device=device)

# Sum loss values from all processes
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)

# Average the loss (divide by world size)
avg_loss = loss_tensor.item() / world_size

if rank == 0:  # Log only on the main process
    print(f"Epoch {epoch}, Batch {batch_idx}, Avg Loss: {avg_loss:.4f}")

# Note: the backward pass uses the original 'loss' tensor
loss.backward()
optimizer.step()
This example shows reducing the loss. You would do something similar for accuracy or other metrics. More sophisticated libraries like torchmetrics often have built-in support for distributed environments.
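As an illustration, a hand-rolled global accuracy can be aggregated the same way. This fragment is only a sketch that would sit inside the loop of a classification model; outputs, target, device, rank, and the initialized process group are assumed to exist.

# --- Inside an evaluation loop (sketch, assumes a classification model) ---
correct = (outputs.argmax(dim=1) == target).sum().float()
total = torch.tensor(float(target.size(0)), device=device)
stats = torch.stack([correct.to(device), total])  # [correct, total] on this rank

# Sum correct and total counts across all processes, then divide.
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
global_accuracy = (stats[0] / stats[1]).item()

if rank == 0:
    print(f"Global accuracy: {global_accuracy:.4f}")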
Saving checkpoints (model state, optimizer state) should typically only be done by one process (usually rank 0) to prevent multiple processes from writing to the same file simultaneously. When saving a DDP-wrapped model, access the underlying model using .module.
# --- Inside your saving logic ---
if rank == 0:
    checkpoint = {
        'epoch': epoch,
        # Access the original model's state dict via .module
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Add any other necessary info
    }
    torch.save(checkpoint, f"model_epoch_{epoch}.pt")
    print(f"Checkpoint saved at epoch {epoch} by rank {rank}.")

# --- Loading logic ---
# Ensure all processes load the same checkpoint before wrapping the model
map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank}  # Map saved weights to the current device
checkpoint = torch.load(checkpoint_path, map_location=map_location)

# Load state dicts BEFORE wrapping with DDP
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# ... load other info ...

# After loading, move the model to the device and wrap it with DDP
model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# Ensure all processes have loaded before continuing
dist.barrier()
Using dist.barrier() after loading ensures that no process moves ahead until all processes have successfully loaded the checkpoint data, preventing potential race conditions.

Call dist.destroy_process_group() at the end of your script or within a finally block to release resources.
# --- At the very end of the main execution ---
# ... training finished ...
cleanup()
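If you want the cleanup to run even when training fails, one option is to guard it with try/finally. This is a sketch only; run_training is a hypothetical stand-in for your training loop, and setup/cleanup are the helpers defined earlier.

def main(rank, world_size):
    setup(rank, world_size)
    try:
        run_training()  # hypothetical: your training loop goes here
    finally:
        # Always release the process group, even if training raises an exception.
        cleanup()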
The standard way to launch a PyTorch DDP script is the torchrun utility (which supersedes torch.distributed.launch). It handles setting the environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) and spawning the processes.
Assuming your script is named train_ddp.py and you want to use 2 GPUs on the current machine:
torchrun --standalone --nproc_per_node=2 train_ddp.py --arg1 value1 --arg2 value2
- --standalone: Indicates single-node training.
- --nproc_per_node: The number of processes (and typically GPUs) to use on this node. Set this to the number of GPUs you want to utilize.
- train_ddp.py: Your script name.
- --arg1 value1 ...: Any command-line arguments your script expects.

torchrun will spawn nproc_per_node copies of your script, each with the correct environment variables set, triggering the setup function and the subsequent DDP logic within each process.
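If you want to see exactly what torchrun hands each process, a tiny probe script is enough; the file name probe_env.py is made up for this example.

# probe_env.py -- print the variables torchrun sets for each spawned process.
import os

keys = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
print({k: os.environ.get(k) for k in keys})

Running torchrun --standalone --nproc_per_node=2 probe_env.py prints two dictionaries, one per process, differing only in RANK and LOCAL_RANK.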
Here’s a skeleton combining the elements:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import argparse

# --- Dummy Model and Dataset ---
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

class ToyDataset(Dataset):
    def __init__(self, size=1000):
        self.size = size
        self.features = torch.randn(size, 10)
        self.labels = torch.randn(size, 1)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
# --- End Dummy ---

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'  # Ensure this port is free
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # Use the global rank as the local rank for single-node simplicity

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size, args):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    # 1. Dataset and Sampler
    dataset = ToyDataset(size=args.dataset_size)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # Effective batch size = args.batch_size * world_size
    loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=2, pin_memory=True)

    # 2. Model
    model = ToyModel().to(device)
    model = DDP(model, device_ids=[rank])

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    print(f"Rank {rank} starting training...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)  # Important for shuffling
        epoch_loss = 0.0
        num_batches = 0
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()  # DDP handles gradient sync here
            optimizer.step()

            # Aggregate loss for logging (optional but good practice)
            loss_tensor = torch.tensor([loss.item()], device=device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            epoch_loss += loss_tensor.item()
            num_batches += 1

        avg_epoch_loss = epoch_loss / (num_batches * world_size)  # Average across all batches and processes

        if rank == 0:  # Log only from rank 0
            print(f"Epoch {epoch+1}/{args.epochs}, Avg Loss: {avg_epoch_loss:.4f}")

        # --- Checkpointing (Example) ---
        if rank == 0 and (epoch + 1) % args.save_interval == 0:
            checkpoint_path = f"model_epoch_{epoch+1}.pt"
            torch.save(model.module.state_dict(), checkpoint_path)
            print(f"Rank {rank} saved checkpoint to {checkpoint_path}")

        dist.barrier()  # Ensure all processes finish the epoch before proceeding

    cleanup()
    if rank == 0:
        print("Training complete.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size PER GPU')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--dataset_size', type=int, default=2048, help='Total dataset size')
    parser.add_argument('--save_interval', type=int, default=2, help='Save checkpoint every N epochs')
    # Note: rank, world_size, local_rank are set by the launcher (torchrun)
    # and read from the environment below.
    args = parser.parse_args()

    # torchrun sets these environment variables
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])  # Often used for device assignment

    # Start the training function for the current process
    train(rank, world_size, args)  # Pass rank and world_size explicitly
To run this script using 2 GPUs:
torchrun --standalone --nproc_per_node=2 train_ddp.py --epochs 10 --batch_size 32
This practical exercise demonstrates the fundamental changes needed to adapt a single-process script for multi-GPU data-parallel training using DistributedDataParallel. While real-world applications often involve more complex metric handling, logging, and checkpointing strategies, these core steps form the foundation for scaling your PyTorch training jobs. Remember to monitor GPU utilization (nvidia-smi) and training time to observe the benefits of distributed training.