As your models and datasets grow in complexity and size, training them on a single GPU or CPU can become prohibitively slow. Distributed training allows you to parallelize this workload across multiple processing units, be they GPUs on a single machine or spread across several machines. This not only accelerates training but also enables you to work with models or batch sizes that would otherwise exceed the memory capacity of a single device. If you've used tf.distribute.Strategy
in TensorFlow, you're already familiar with the benefits and general ideas behind distributed training. PyTorch offers its own powerful tools, primarily within the torch.distributed
package, to achieve similar outcomes.
There are two primary strategies for distributing the training workload: data parallelism and model parallelism.
Data parallelism is the most common strategy. In this approach, the model is replicated on each available device (e.g., GPU). Each replica then processes a different subset (a shard or mini-batch) of the input data. The gradients computed by each replica are subsequently aggregated, and the model weights are updated synchronously across all replicas.
Diagram illustrating the data parallelism approach. The model is replicated on each GPU, processes a unique data shard, and gradients are aggregated before updating weights.
In PyTorch, data parallelism is primarily achieved using torch.nn.parallel.DistributedDataParallel
(DDP). This module wraps your existing model and handles the complexities of data distribution, gradient synchronization, and model updates across multiple processes, typically one per GPU. DDP is favored over the older torch.nn.DataParallel
(DP) because DDP uses multiprocessing, which avoids Python's Global Interpreter Lock (GIL) limitations and generally offers better performance, especially for models with significant Python overhead or when using multiple nodes.
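For contrast, the older DataParallel path runs everything in a single process, which is what makes it simple but GIL-bound. A minimal sketch (assuming at least one GPU is available; nn.Linear stands in for a real model):

import torch
import torch.nn as nn

# Single-process data parallelism: one Python process drives all visible GPUs.
# Simple to enable, but limited by the GIL and restricted to a single machine.
model = nn.DataParallel(nn.Linear(128, 10).cuda())
output = model(torch.randn(32, 128).cuda())  # the batch is split across visible GPUs
print(output.shape)  # torch.Size([32, 10])

DDP, by contrast, launches one process per GPU, as the rest of this section shows.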
The workflow for DDP usually involves wrapping your model in DistributedDataParallel and using a DistributedSampler with your DataLoader to ensure each process receives a unique portion of the dataset.
This approach is similar to TensorFlow's tf.distribute.MirroredStrategy for single-node, multi-GPU training or tf.distribute.MultiWorkerMirroredStrategy for multi-node scenarios. The main idea is that each worker has a complete copy of the model and works on a part of the data.
Model parallelism is employed when a model is too large to fit into the memory of a single GPU. Instead of replicating the entire model on each device, different parts of the model (e.g., layers or blocks of layers) are placed on different devices. Data flows sequentially through these parts across the devices during the forward and backward passes.
Diagram illustrating model parallelism. Different parts of the model are placed on separate GPUs, and data flows between them.
Implementing model parallelism can be more complex than data parallelism because you need to manually manage the placement of model components and the transfer of intermediate activations and gradients between devices. PyTorch allows you to assign different parts of your model to different devices using .to(device)
. For example, you could put the embedding layers of a large NLP model on one GPU and subsequent transformer blocks on other GPUs.
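As an illustration, here is a minimal sketch of manual model parallelism that splits a small network across two GPUs; the layer sizes are arbitrary and it assumes at least two GPUs are visible:

import torch
import torch.nn as nn

class TwoDeviceModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First stage lives on GPU 0, second stage on GPU 1.
        self.stage1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()).to("cuda:0")
        self.stage2 = nn.Linear(512, 10).to("cuda:1")

    def forward(self, x):
        x = self.stage1(x.to("cuda:0"))
        # Intermediate activations must be moved to the device holding the next stage.
        return self.stage2(x.to("cuda:1"))

model = TwoDeviceModel()
out = model(torch.randn(8, 1024))   # the output tensor lives on cuda:1
out.sum().backward()                # autograd routes gradients back across devices

Note that the two GPUs work sequentially in this sketch; only pipeline-style scheduling (discussed next) keeps both busy at the same time.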
While PyTorch provides the basic tools for manual model parallelism, more sophisticated forms, such as pipeline parallelism (where devices work on different stages of a pipeline simultaneously for different micro-batches), often benefit from specialized libraries like FairScale or DeepSpeed, which build upon PyTorch's primitives. The torch.distributed.rpc
module also provides a framework for more general distributed computation patterns, which can be used to implement custom model parallel strategies.
TensorFlow users might find similarities in manually placing tf.Variable
or layer computations on specific devices. Both frameworks require careful consideration of communication overhead, as data moving between GPUs can become a bottleneck.
The torch.distributed
package is the foundation for distributed training in PyTorch. Here are some of its central components:
Process Groups (torch.distributed.init_process_group): Before any distributed operations can occur, processes must join a group. This function initializes the distributed environment. You need to specify:
- backend: The communication backend to use (e.g., gloo, nccl for GPU, or mpi). nccl is generally recommended for GPU-based training due to its high performance.
- init_method: How processes discover each other (e.g., env:// for environment variable setup, or tcp://<master_addr>:<master_port>).
- world_size: The total number of processes participating in the job.
- rank: A unique identifier for the current process, from 0 to world_size - 1.
Communication Primitives: torch.distributed provides several functions for collective communication among processes:
- all_reduce(tensor, op=ReduceOp.SUM): Reduces the tensor data across all processes; each process ends up with the same final result (e.g., the sum of all tensors). This is fundamental for averaging gradients in DDP.
- broadcast(tensor, src): Copies a tensor from the process with rank src to all other processes in the group.
- scatter(tensor, scatter_list, src): Scatters a list of tensors to all processes in a group.
- gather(tensor, gather_list, dst): Gathers a list of tensors from all processes in a group to a destination process.

torch.nn.parallel.DistributedDataParallel (DDP): As mentioned, this is the workhorse for data parallelism. It wraps your model and, among other things, handles:
- keeping the model replicas consistent while each process works on its own shard of data (typically supplied by a DataLoader with a DistributedSampler), and
- the all_reduce operation on gradients during the backward pass.

Launch Utilities:
- torch.multiprocessing.spawn(fn, args=(), nprocs=None, ...): A utility to spawn nprocs processes that will run the target function fn. Often used for single-node multi-GPU training.
- torchrun (formerly python -m torch.distributed.launch): A command-line utility provided by PyTorch to launch distributed training jobs, especially useful for multi-node setups. It handles setting up environment variables like MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK for each process.

Single-node, multi-GPU training is a common scenario: one machine with several GPUs.
The typical steps are:
- Import torch.distributed and torch.multiprocessing.
- Define a training function that takes rank and world_size as arguments.
- Call dist.init_process_group() with the appropriate backend (nccl), rank, and world size.
- Pin the process to its GPU with torch.cuda.set_device(rank).
- Move the model to that device with model.to(rank).
- Wrap the model in DistributedDataParallel: model = DDP(model, device_ids=[rank]).
- Use torch.utils.data.distributed.DistributedSampler with your DataLoader to ensure each process gets a unique part of the data.
- Call dist.destroy_process_group() at the end.
- Use mp.spawn() in your main execution block (if __name__ == '__main__':) to launch the training function across multiple processes.

Here's a simplified structure:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
# Assume MyModel and MyDataset are defined elsewhere
def setup(rank, world_size):
    # For TCP initialization
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
    # dist.init_process_group("nccl", init_method='env://', rank=rank, world_size=world_size)

    # Simpler initialization for single-node using a file (alternative to env variables)
    # Ensure the file path is accessible and unique per job
    init_file = "file:///tmp/my_shared_file_for_dist_init"
    dist.init_process_group(backend="nccl", init_method=init_file,
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train_fn(rank, world_size, epochs):
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)

    # Create model and move it to GPU with id rank
    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank], output_device=rank)  # output_device can be useful

    # Dummy dataset for illustration
    dataset = MyDataset(...)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=2)  # num_workers per process

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Important for shuffling with DistributedSampler
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = torch.nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0 and rank == 0:  # Log from rank 0
                print(f"Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item()}")

    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # Number of GPUs
    epochs = 10
    # mp.spawn(train_fn, args=(world_size, epochs), nprocs=world_size, join=True)
    # Note: For actual execution, replace MyModel and MyDataset with real implementations
    # and uncomment mp.spawn.
    print(f"Example setup for {world_size} GPUs. To run, implement MyModel, MyDataset and uncomment mp.spawn.")
You can control which GPUs are visible to PyTorch using the CUDA_VISIBLE_DEVICES
environment variable. For instance, CUDA_VISIBLE_DEVICES=0,1
would make only GPU 0 and GPU 1 available.
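For example, assuming a hypothetical script train_ddp.py that follows the structure above, launching it restricted to the first two GPUs could look like:

CUDA_VISIBLE_DEVICES=0,1 python train_ddp.py

Inside the script, torch.cuda.device_count() would then report 2, and ranks 0 and 1 map to physical GPUs 0 and 1.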
Training across multiple machines introduces more setup complexity, primarily related to network communication and process discovery. torchrun
is the recommended tool for this.
You'd typically launch your training script on each node using torchrun
. Key parameters for torchrun
include:
- --nnodes: Total number of nodes.
- --nproc_per_node: Number of processes (usually GPUs) per node.
- --rdzv_id: A unique job ID.
- --rdzv_backend: Rendezvous backend (e.g., c10d for TCP-based rendezvous).
- --rdzv_endpoint: Endpoint for the rendezvous server (e.g., MASTER_NODE_IP:PORT). One node acts as the master for coordination.

The PyTorch script itself (like train_fn
above) largely remains the same. torchrun
sets up the environment variables (MASTER_ADDR
, MASTER_PORT
, WORLD_SIZE
, RANK
) that init_process_group(backend="nccl", init_method="env://")
uses to establish communication. Cluster management systems like Slurm or Kubernetes often have integrations or utilities to simplify launching torchrun
across nodes.
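Putting these flags together, a launch command for a hypothetical two-node job with four GPUs per node (the script name, its arguments, the job ID, and the endpoint address are all placeholders) might look roughly like this, run on each node:

torchrun \
  --nnodes=2 \
  --nproc_per_node=4 \
  --rdzv_id=job_42 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=10.0.0.1:29500 \
  train_ddp.py --epochs 10

With rendezvous-based launching, the same command can typically be used on every node; torchrun then derives WORLD_SIZE (here 2 x 4 = 8) and a unique RANK for each spawned process.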
Comparison with tf.distribute.Strategy
If you've used TensorFlow's tf.distribute.Strategy, you'll find parallels:
- tf.distribute.MirroredStrategy: This is highly analogous to PyTorch's DistributedDataParallel (DDP) on a single node with multiple GPUs. Both replicate the model on each GPU and use AllReduce for gradient synchronization. The TensorFlow API might abstract away some of the explicit process group setup, integrating it more directly into the Strategy scope.
- tf.distribute.MultiWorkerMirroredStrategy: This corresponds to DDP in a multi-node setting. Both require coordinating processes across machine boundaries. TensorFlow's strategy relies on the TF_CONFIG environment variable for configuration, while PyTorch often uses torchrun or manual setup of similar environment variables (MASTER_ADDR, etc.).
- tf.distribute.ParameterServerStrategy: This involves dedicated parameter servers storing variables, while workers compute gradients. While PyTorch's DDP is more akin to AllReduce architectures, torch.distributed.rpc can be used to build parameter server-style training, though it's less common for typical deep learning workloads compared to DDP.
- Model parallelism (via tf.distribute.experimental.TPUStrategy or manual placement): TensorFlow's support for model parallelism, especially on TPUs via TPUStrategy, can involve sophisticated model sharding. Manually, tf.device scopes are used, similar to PyTorch's .to(device).

The primary difference often lies in the explicitness of setup. PyTorch's torch.distributed and DDP give you fine-grained control but require a bit more boilerplate for initialization and process launching, especially when compared to the context-manager style of tf.distribute.Strategy. However, the underlying principles of distributing data and synchronizing gradients are fundamentally similar.
A few practical considerations for DDP training:
- DistributedSampler: It's essential to use torch.utils.data.distributed.DistributedSampler. This sampler ensures that each process loads a unique, non-overlapping subset of the dataset. Remember to call sampler.set_epoch(epoch) at the beginning of each epoch if you want shuffling to work correctly across epochs.
- Saving checkpoints: Only one process (typically rank == 0) should save the model checkpoint to avoid race conditions or multiple writes. You can access the underlying model from DDP using ddp_model.module.state_dict(). A short saving sketch follows the loading snippet below.
- Loading checkpoints: Load the state_dict with an appropriate map_location, for example onto the CPU first or directly to each rank's GPU, to avoid GPU memory issues on rank 0 if the model is large:
# In your training function, after setup(rank, world_size)
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} # Map to current rank's GPU
checkpoint = torch.load(PATH_TO_CHECKPOINT, map_location=map_location)
model.load_state_dict(checkpoint['model_state_dict'])
# Potentially load optimizer state, epoch, etc.
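For the saving side mentioned above, a minimal sketch (reusing the ddp_model, optimizer, and rank names from the training function earlier; the file path is a placeholder) could look like:

import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, optimizer, epoch, rank, path="checkpoint.pt"):
    # Only rank 0 writes to disk; ddp_model.module is the underlying, unwrapped model.
    if rank == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": ddp_model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, path)
    # Optional: make all processes wait until the checkpoint exists before continuing.
    dist.barrier()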
Alternatively, when loading, rank 0 can load the checkpoint and then broadcast the state dict to the other processes.
- Batch normalization: SyncBatchNorm (torch.nn.SyncBatchNorm) can be used to synchronize statistics across all processes, which can be beneficial if per-GPU batch sizes are very small. You can convert existing BatchNorm layers with torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) before wrapping the model in DDP.
- Logging: Guard print statements and logging with if rank == 0: to avoid cluttered output.
- Debugging: torch.distributed.barrier() can be used to synchronize processes at certain points for debugging.
- Reproducibility: Set random seeds in each process, for example with torch.manual_seed(seed).

By understanding these approaches and components, you can effectively scale your PyTorch training workflows, much like you would with tf.distribute.Strategy in TensorFlow, enabling you to tackle larger and more complex machine learning problems.