Data Parallelism (DP) stands as the most straightforward and commonly employed strategy for distributing the computational load of training deep learning models, including large language models. As highlighted in the chapter introduction, training large models often encounters bottlenecks related to computation time and memory constraints on single devices. Data parallelism directly addresses the computation time aspect by processing different parts of the data simultaneously across multiple devices.
The core principle is simple: replicate the entire model onto each available processing unit (like a GPU), and then divide each global data batch into smaller pieces, called micro-batches. Each device processes its assigned micro-batch independently.
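As a toy illustration of this split (the sizes below are made up for the example), slicing a global batch into micro-batches is just a chunk along the batch dimension:

import torch

K = 4                                    # number of devices (model replicas)
global_batch = torch.randn(32, 512)      # illustrative global batch: 32 samples, 512 features
micro_batches = torch.chunk(global_batch, K, dim=0)  # 4 micro-batches of 8 samples each

for device_id, micro_batch in enumerate(micro_batches):
    print(f"replica {device_id} processes a micro-batch of shape {tuple(micro_batch.shape)}")

In a real multi-process setup the split is not performed on one materialized tensor; each process loads only its own shard of the data, which is exactly what DistributedSampler does in the example later in this section.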
Let's break down the typical steps involved in one training iteration using data parallelism with K devices:

1. Split: the global batch is divided into K micro-batches, one for each device.
2. Forward pass: each device runs the forward pass on its micro-batch through its local model replica and computes the loss.
3. Backward pass: each device computes the gradients $g_i$ of the loss with respect to the model parameters, using only its own micro-batch.
4. Synchronize: the AllReduce operation sums (or averages) the gradients $g_i$ from all K devices and distributes the resulting synchronized gradient $g_{\text{sync}}$ back to every device. Often, averaging is used:

$$g_{\text{sync}} = \frac{1}{K} \sum_{i=1}^{K} g_i$$

5. Update: each device applies the same optimizer step using $g_{\text{sync}}$, so all replicas remain identical.
This cycle repeats for each batch in the training dataset.
Data Parallelism workflow: The global batch is split, processed in parallel on devices holding model replicas, gradients are synchronized via AllReduce, and synchronized updates are applied to each replica.
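To make the synchronization step concrete, here is a minimal sketch of what gradient averaging amounts to if written by hand with torch.distributed, assuming a process group has already been initialized (as in the DDP example below):

import torch
import torch.distributed as dist

def sync_gradients(model: torch.nn.Module, world_size: int) -> None:
    """Average gradients across all processes after loss.backward()."""
    for param in model.parameters():
        if param.grad is None:
            continue
        # Sum the local gradients g_i from every device, in place
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        # Divide by K to obtain the averaged gradient g_sync
        param.grad /= world_size

In practice you rarely write this loop yourself: DistributedDataParallel performs the equivalent synchronization automatically and overlaps it with the backward pass by bucketing gradients.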
Frameworks such as PyTorch (with its DistributedDataParallel module) provide high-level abstractions that make implementing data parallelism relatively straightforward, often requiring only minor modifications to standard single-device training code.

Despite its advantages, data parallelism has significant limitations, especially for the large language models this course focuses on:

- Memory redundancy: every device holds a full replica of the model parameters, gradients, and optimizer states. Data parallelism therefore does nothing to reduce the per-device memory footprint, and it cannot help when the model itself does not fit on a single device.
- Communication overhead: the AllReduce operation requires communicating gradients across all devices involved. The amount of data transferred is proportional to the size of the model parameters. As the number of devices (K) increases, or if the interconnect bandwidth between devices is limited, this synchronization step can become a significant bottleneck, diminishing the speedup gained from parallel computation. The time taken for AllReduce can sometimes dominate the computation time, especially for smaller models or very fast per-device computation.
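To get a feel for the volume involved, the following back-of-the-envelope sketch estimates gradient traffic per device for a single step, assuming fp32 gradients and a ring-based AllReduce in which each device sends and receives roughly 2(K-1)/K times the gradient size; the parameter counts are only illustrative.

def allreduce_bytes_per_device(num_params: int, world_size: int,
                               bytes_per_element: int = 4) -> float:
    """Approximate bytes each device communicates in one gradient AllReduce."""
    grad_bytes = num_params * bytes_per_element             # size of the full gradient
    return 2 * (world_size - 1) / world_size * grad_bytes   # ring AllReduce traffic

for num_params in (125_000_000, 1_300_000_000, 7_000_000_000):
    gigabytes = allreduce_bytes_per_device(num_params, world_size=8) / 1e9
    print(f"{num_params / 1e9:.2f}B params -> ~{gigabytes:.0f} GB per device per step")

At tens of gigabytes per step for multi-billion-parameter models, the interconnect bandwidth, rather than the GPUs, can easily set the pace of training.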
PyTorch's torch.nn.parallel.DistributedDataParallel module is the standard way to implement data parallelism in a multi-process setting, and is generally preferred over the older torch.nn.DataParallel for performance and flexibility. Here's an outline:
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Assume simple model and dataset definitions exist:
# class MyLargeModel(nn.Module): ...
# class YourDataset(torch.utils.data.Dataset): ...

def setup(rank, world_size):
    """Initialize the process group."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group; use the 'nccl' backend for NVIDIA GPUs
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    """Destroy the process group."""
    dist.destroy_process_group()

def train_worker(rank, world_size, model_args, data_args, train_args):
    """Main training function for a single worker process."""
    print(f"Running DDP training on rank {rank}.")
    setup(rank, world_size)

    # Instantiate the model and move it to the assigned GPU
    model = MyLargeModel(**model_args).to(rank)

    # Wrap the model with DDP; this handles gradient synchronization automatically
    ddp_model = DDP(model, device_ids=[rank])

    # DistributedSampler ensures each process gets a different slice of the data
    dataset = YourDataset(**data_args)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    # Use pin_memory=True and num_workers > 0 for optimized data loading
    dataloader = DataLoader(
        dataset,
        batch_size=train_args['micro_batch_size'],
        sampler=sampler,
        pin_memory=True,
        num_workers=4
    )

    optimizer = torch.optim.AdamW(
        ddp_model.parameters(),
        lr=train_args['learning_rate']
    )

    # --- Simplified Training Loop ---
    ddp_model.train()
    for epoch in range(train_args['num_epochs']):
        sampler.set_epoch(epoch)  # Important for shuffling with DDP
        for batch in dataloader:
            inputs = batch['input_ids'].to(rank)
            labels = batch['labels'].to(rank)

            optimizer.zero_grad()
            # Forward pass (assumes the model returns an object with a .loss attribute)
            outputs = ddp_model(inputs, labels=labels)
            loss = outputs.loss

            # Backward pass - DDP automatically averages gradients across processes
            loss.backward()

            # Optimizer step - applies the update based on the averaged gradients
            optimizer.step()

        if rank == 0:  # Log only on the main process
            print(f"Epoch: {epoch}, Loss: {loss.item()}")
    # --- End Simplified Loop ---

    cleanup()

if __name__ == '__main__':
    # Example configuration (replace with actual args)
    world_size = torch.cuda.device_count()  # e.g., 4 GPUs
    model_args = {'vocab_size': 50257, 'hidden_size': 768, 'num_layers': 12}
    data_args = {'data_path': '/path/to/data'}
    train_args = {
        'micro_batch_size': 8,
        'learning_rate': 1e-4,
        'num_epochs': 3
    }
    # Note: Global batch size = micro_batch_size * world_size

    # Spawn one worker process per GPU
    mp.spawn(
        train_worker,
        args=(world_size, model_args, data_args, train_args),
        nprocs=world_size,
        join=True
    )
In this sketch:
- setup initializes the distributed environment. Each process gets a unique rank from 0 to world_size - 1.
- DistributedSampler is used with the DataLoader to ensure each process gets a non-overlapping partition of the data.
- ddp_model = DDP(model, device_ids=[rank]) wraps the local model replica. DDP intercepts the backward pass, performs the AllReduce operation on the gradients automatically, and ensures all processes have the averaged gradients before optimizer.step() is called.

Data parallelism is a fundamental technique for scaling training throughput. However, its memory limitations necessitate exploring other strategies like tensor and pipeline parallelism, especially when dealing with the enormous scale of modern large language models. Often, the most effective approaches combine data parallelism with these other techniques, which we will examine next.