Training contemporary deep learning models presents significant computational hurdles. Models with billions of parameters might not fit onto a single accelerator's memory, and iterating through terabytes of data on one machine can stretch training times from days to weeks or even months. Distributed computing offers a path forward, enabling the pooling of resources from multiple devices (like GPUs) across one or several machines to tackle these large scale problems. Before implementing specific PyTorch strategies like DistributedDataParallel
(DDP) or Fully Sharded Data Parallelism (FSDP), it's essential to grasp the fundamental vocabulary and communication patterns used in these distributed setups.
When discussing distributed training, several terms appear frequently. Understanding their precise meaning is important for configuring and debugging distributed jobs.
Node: A physical or virtual machine participating in the training job, typically hosting one or more accelerators (GPUs).
Process (worker): A single running instance of the training script. The usual convention in GPU training is one process per GPU.
Rank: The unique integer identifier of a process within the job, ranging from 0 to world size minus 1. Rank 0 is commonly given extra responsibilities such as logging and checkpointing.
Local rank: The identifier of a process within its own node, typically used to select which GPU that process should use.
World size: The total number of processes participating in the job across all nodes.
Backend: The communication library that actually moves data between processes. The torch.distributed package supports several backends: Gloo (works on CPU and is useful for debugging), NCCL (the standard, high-performance choice for NVIDIA GPUs), and MPI (available when PyTorch is built with MPI support).
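Before any of these concepts come into play, each process must join a process group. The following is a minimal sketch of that setup; it assumes a launcher such as torchrun has already set the standard environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, LOCAL_RANK), and the helper name init_distributed is purely illustrative.

```python
import os

import torch
import torch.distributed as dist


def init_distributed():
    # Pick NCCL when GPUs are available, otherwise fall back to Gloo.
    backend = "nccl" if torch.cuda.is_available() else "gloo"

    # With no explicit init_method, the default "env://" method reads
    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment.
    dist.init_process_group(backend=backend)

    rank = dist.get_rank()                               # global rank across all nodes
    world_size = dist.get_world_size()                   # total number of processes
    local_rank = int(os.environ.get("LOCAL_RANK", 0))    # rank within this node

    if backend == "nccl":
        torch.cuda.set_device(local_rank)                # bind this process to one GPU

    print(f"Initialized rank {rank}/{world_size} (local rank {local_rank})")
    return rank, world_size, local_rank
```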
Distributed training relies heavily on collective communication operations, where multiple processes synchronize and exchange data simultaneously. These are the primitives upon which higher level strategies like DDP are built. torch.distributed provides functions for these operations; a runnable sketch demonstrating several of them follows this list:
Broadcast (torch.distributed.broadcast): Sends a tensor from a single designated process (the src rank) to all other processes in the group. This is commonly used at the start of training to ensure all workers begin with the exact same initial model parameters.
Reduce (torch.distributed.reduce): Collects tensors from all processes in the group, applies a specified reduction operation (such as SUM, AVG, MAX, or MIN), and stores the result on a single destination process (the dst rank).
All-Reduce (torch.distributed.all_reduce): Similar to Reduce, but the final result of the reduction is distributed back to all processes in the group. This is the foundation of DDP, where gradients computed independently on each worker are averaged across all workers, ensuring consistent model updates everywhere.
Scatter (torch.distributed.scatter): Takes a list of tensors on a single source process (src) and distributes one tensor from the list to each process in the group (including itself). The i-th tensor in the list goes to the process with rank i.
Gather (torch.distributed.gather): The inverse of Scatter. Each process sends its tensor to a designated destination process (dst), which collects them into a list of tensors ordered by rank.
All-Gather (torch.distributed.all_gather): Similar to Gather, but every process in the group receives the full list of tensors from all other processes. Useful when every worker needs the complete set of results, for example when gathering embeddings computed in parallel.
Reduce-Scatter (torch.distributed.reduce_scatter): Performs an element-wise reduction (like All-Reduce) on a list of input tensors across all processes, then scatters the reduced results so that each process receives one portion of the final reduced tensor. This can be more efficient than separate Reduce and Scatter operations in some scenarios.
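To make these concrete, here is a minimal, self-contained sketch that spawns a small process group on one machine using the Gloo backend and exercises a few of the collectives above (broadcast, all_reduce, and all_gather). The choice of world size and the function names run and world_size are illustrative, not part of any standard API.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Each spawned process joins the same group; Gloo works on CPU-only machines.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )

    # Broadcast: rank 0's tensor overwrites every other rank's copy.
    params = torch.full((3,), float(rank))
    dist.broadcast(params, src=0)
    # params is now [0., 0., 0.] on every rank.

    # All-Reduce: sum each rank's local "gradient" and share the result with all.
    grad = torch.ones(3) * (rank + 1)
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    # With 4 ranks, grad == [10., 10., 10.] everywhere (1 + 2 + 3 + 4).

    # All-Gather: every rank receives every rank's tensor, ordered by rank.
    local = torch.tensor([float(rank)])
    gathered = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # gathered == [tensor([0.]), tensor([1.]), tensor([2.]), tensor([3.])] on every rank.

    if rank == 0:
        print("broadcast:", params, "all_reduce:", grad, "all_gather:", gathered)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4  # illustrative; any small number of processes works
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```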
Visualizing these operations can aid understanding. Consider a simple All-Reduce operation for summing gradients in a four-process setup: each process computes its local gradient (Gi). During the All-Reduce step, these gradients are communicated and summed across all processes. The final summed gradient is then made available back on every process, ready for the optimizer step.
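As an illustration of this pattern, the following sketch performs the manual gradient averaging that DDP handles for you automatically after each backward pass. It assumes a process group is already initialized (for example, by the earlier snippet) and that model is an ordinary nn.Module replicated on every rank; the helper name average_gradients is hypothetical.

```python
import torch
import torch.distributed as dist


def average_gradients(model: torch.nn.Module) -> None:
    """Average gradients across all ranks, mimicking what DDP does after backward()."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient across every process...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ...then divide by the number of processes to obtain the average.
            param.grad /= world_size


# Typical use inside a training loop (assumes loss, model, and optimizer exist):
#   loss.backward()
#   average_gradients(model)
#   optimizer.step()
```

Real DDP is more sophisticated (it buckets gradients and overlaps communication with the backward pass), but the underlying collective is the same All-Reduce shown here.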
These fundamental concepts directly map to the distributed training techniques discussed later in this chapter:
DistributedDataParallel (DDP): Uses All-Reduce to average gradients calculated on different data batches across workers. Broadcast is used initially to synchronize model weights.
Model and pipeline parallelism: Rely on point-to-point communication (the send/recv primitives, not detailed here but part of torch.distributed) or specific collective operations like Scatter and Gather between specific ranks holding different parts of the model or processing different micro-batches.
Fully Sharded Data Parallelism (FSDP): Uses All-Gather to reconstruct full parameters for the forward/backward pass within a layer, and Reduce-Scatter to average gradients and shard them back efficiently across workers.
Understanding these building blocks (nodes, processes, ranks, backends, and the collective communication patterns) provides a solid foundation for effectively implementing and troubleshooting distributed training workflows in PyTorch. With this vocabulary established, we can proceed to examine how PyTorch orchestrates these elements for different parallelization strategies.