While high-level abstractions like DistributedDataParallel (DDP) handle many details of distributed training automatically, understanding the lower-level communication primitives provided by the torch.distributed package offers deeper insight and enables the implementation of custom parallelization strategies. These primitives are the fundamental building blocks for orchestrating communication between processes in a distributed setting.
Before using any communication primitives, the distributed environment must be initialized, typically with torch.distributed.init_process_group. This establishes the communication backend (such as NCCL or Gloo) and assigns a unique rank to each process within the total world_size. Once initialized, processes within the default group (or a custom-created group) can coordinate using collective and point-to-point operations.
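A minimal sketch of this initialization step: when a script is launched with torchrun, the launcher exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, and init_process_group picks them up through its default env:// method. The init_distributed helper name below is illustrative; the examples that follow instead set these variables manually inside a setup function.

import torch
import torch.distributed as dist

def init_distributed():
    # Prefer NCCL on GPU machines, fall back to Gloo on CPU-only machines
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # With no explicit rank/world_size, the default env:// init method reads
    # RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT from the environment
    dist.init_process_group(backend=backend)
    return dist.get_rank(), dist.get_world_size()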
Collective Operations

Collective operations involve communication among all processes within a group. They are essential for tasks like synchronizing gradients or distributing model parameters. Here are some of the most frequently used collectives:
Broadcast (dist.broadcast)

This operation sends a tensor from a single source process (src) to all other processes in the group. It is commonly used to ensure all processes start with the same initial model parameters.
import torch
import torch.distributed as dist
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def run_broadcast(rank, world_size):
    setup(rank, world_size)
    tensor = torch.zeros(1)
    if rank == 0:
        # Source process creates the data
        tensor += 1
    # Rank 0 broadcasts 'tensor' to all other processes
    dist.broadcast(tensor=tensor, src=0)
    print(f"Rank {rank} has data: {tensor[0]}")
    dist.destroy_process_group()

# Assume world_size = 4 for demonstration
# In a real script, this would be launched via torchrun or similar
# run_broadcast(0, 4)
# run_broadcast(1, 4)
# run_broadcast(2, 4)
# run_broadcast(3, 4)
After this operation, the tensor on ranks 1, 2, and 3 will be updated from 0 to 1.

Data flow during a dist.broadcast operation from rank 0 to all other ranks in a 4-process group.
All-Reduce (dist.all_reduce)

This operation combines tensors from all processes using a specified reduction operation (op, e.g., dist.ReduceOp.SUM or dist.ReduceOp.AVG) and distributes the final result back to all processes. This is the cornerstone of gradient synchronization in DDP: each process contributes its local gradient, the gradients are summed (or averaged) across all processes, and every process receives the combined result.
import torch
import torch.distributed as dist
import os

# setup function assumed to be defined as above
def run_all_reduce(rank, world_size):
    setup(rank, world_size)
    # Each rank creates data based on its rank
    tensor = torch.tensor([rank + 1], dtype=torch.float32)
    print(f"Rank {rank} initial tensor: {tensor[0]}")
    # Perform all-reduce with SUM operation
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Result (sum of 1+2+3+4 = 10) is available on all ranks
    print(f"Rank {rank} final tensor: {tensor[0]}")
    dist.destroy_process_group()

# Example execution for world_size = 4
# run_all_reduce(0, 4) # Initial: 1, Final: 10
# run_all_reduce(1, 4) # Initial: 2, Final: 10
# run_all_reduce(2, 4) # Initial: 3, Final: 10
# run_all_reduce(3, 4) # Initial: 4, Final: 10
Flow of dist.all_reduce with the SUM operation. All ranks contribute data, it is aggregated, and the result is distributed back to all ranks.
Reduce (dist.reduce)

Similar to all_reduce, reduce combines tensors from all processes using a reduction operation. However, the result is only stored on the destination process (dst); other processes do not receive it.
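A brief sketch of the difference, reusing the setup helper and per-rank pattern from the examples in this section (the run_reduce name is just for illustration):

import torch
import torch.distributed as dist

# setup function assumed to be defined as above
def run_reduce(rank, world_size):
    setup(rank, world_size)
    tensor = torch.tensor([rank + 1], dtype=torch.float32)
    # Reduce with SUM; only rank 0 (dst) is guaranteed to hold the result
    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
    if rank == 0:
        # For world_size = 4: 1 + 2 + 3 + 4 = 10
        print(f"Rank 0 reduced sum: {tensor[0]}")
    else:
        # On other ranks the buffer may hold intermediate values, not the final sum
        print(f"Rank {rank} does not receive the reduced result")
    dist.destroy_process_group()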
Scatter (dist.scatter)

This operation takes a list of tensors (scatter_list) on a single source process (src) and distributes one tensor from the list to each process in the group, including the source itself. The i-th tensor in scatter_list is sent to the process with rank i. This is useful for distributing batches of data across processes.
import torch
import torch.distributed as dist
import os

# setup function assumed to be defined as above
def run_scatter(rank, world_size):
    setup(rank, world_size)
    my_tensor = torch.zeros(1)
    scatter_list = None
    if rank == 0:
        # Source rank prepares the list of tensors to scatter
        scatter_list = [torch.tensor([i + 1.0]) for i in range(world_size)]
        print(f"Rank 0 scatter list: {[t.item() for t in scatter_list]}")
    # Rank 0 scatters the list. Each rank receives one tensor into my_tensor.
    dist.scatter(tensor=my_tensor, scatter_list=scatter_list, src=0)
    print(f"Rank {rank} received tensor: {my_tensor.item()}")
    dist.destroy_process_group()

# Example execution for world_size = 4
# run_scatter(0, 4) # Received: 1.0
# run_scatter(1, 4) # Received: 2.0
# run_scatter(2, 4) # Received: 3.0
# run_scatter(3, 4) # Received: 4.0
Data flow for dist.scatter. Rank 0 holds a list of tensors [A, B, C, D] and sends A to Rank 0, B to Rank 1, C to Rank 2, and D to Rank 3.
Gather (dist.gather)

The inverse of scatter. Each process sends its tensor to a single destination process (dst). The destination process receives these tensors and stores them in a list (gather_list), ordered by the rank of the sending process.
import torch
import torch.distributed as dist
import os

# setup function assumed to be defined as above
def run_gather(rank, world_size):
    setup(rank, world_size)
    # Each rank creates its own tensor
    my_tensor = torch.tensor([rank + 1.0])
    gather_list = None
    if rank == 0:
        # Destination rank prepares a list to store gathered tensors
        gather_list = [torch.zeros(1) for _ in range(world_size)]
    # All ranks send their tensor to rank 0
    dist.gather(tensor=my_tensor, gather_list=gather_list, dst=0)
    if rank == 0:
        print(f"Rank 0 gathered list: {[t.item() for t in gather_list]}")
    else:
        print(f"Rank {rank} sent tensor: {my_tensor.item()}")
    dist.destroy_process_group()

# Example execution for world_size = 4
# run_gather(0, 4) # Gathered: [1.0, 2.0, 3.0, 4.0]
# run_gather(1, 4) # Sent: 2.0
# run_gather(2, 4) # Sent: 3.0
# run_gather(3, 4) # Sent: 4.0
Data flow for dist.gather. Ranks 0, 1, 2, and 3 send their tensors A, B, C, and D respectively to Rank 0, which collects them into a list [A, B, C, D].
All-Gather (dist.all_gather)

Similar to gather, but the resulting list of tensors gathered from all processes is distributed back to all processes in the group, so each process receives the same final list.
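A short sketch under the same assumptions as the earlier examples (the run_all_gather name is illustrative). Note that, unlike gather, every rank must pre-allocate the output list:

import torch
import torch.distributed as dist

# setup function assumed to be defined as above
def run_all_gather(rank, world_size):
    setup(rank, world_size)
    my_tensor = torch.tensor([rank + 1.0])
    # Every rank allocates a buffer list to receive one tensor per process
    gathered = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(gathered, my_tensor)
    # For world_size = 4, every rank now holds [1.0, 2.0, 3.0, 4.0]
    print(f"Rank {rank} gathered: {[t.item() for t in gathered]}")
    dist.destroy_process_group()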
Point-to-Point Communication

These operations involve communication between two specific processes, identified by their ranks.
dist.send(tensor, dst): Sends a tensor from the current process to the destination process (dst). This call blocks on the sending side.

dist.recv(tensor, src): Receives a tensor into the provided tensor buffer from the source process (src). This call blocks on the receiving side until the tensor arrives.

While powerful, point-to-point operations require careful management to avoid deadlocks (e.g., two processes each waiting to receive from the other before sending). They are used less often than collectives for standard data-parallel training, but they are important for more complex communication patterns such as model parallelism or custom algorithms, as sketched below.
Blocking and Non-Blocking Operations

Most collective operations (broadcast, all_reduce, scatter, gather, etc.) are blocking (synchronous) by default: program execution on a process pauses until that process has completed its part of the collective communication.
PyTorch also provides non-blocking (asynchronous) versions of many operations, either prefixed with i (e.g., dist.isend, dist.irecv) or requested with an argument (e.g., dist.all_reduce(..., async_op=True)). These calls initiate the communication and return immediately with a Work object (or similar handle). The program can continue executing other tasks while the communication happens in the background, then later check for completion or wait for the operation to finish using methods like wait() on the returned handle.
import torch
import torch.distributed as dist

# Example of non-blocking all-reduce; setup function assumed to be defined as above
def run_async_all_reduce(rank, world_size):
    setup(rank, world_size)
    tensor = torch.ones(1) * rank
    # Initiate non-blocking all-reduce; returns a work handle immediately
    work_handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    # Perform other computations while communication happens...
    # result = compute_something_else()
    # Wait for the all_reduce operation to complete
    work_handle.wait()
    # Now 'tensor' contains the reduced result
    print(f"Rank {rank} async all_reduce result: {tensor[0]}")
    dist.destroy_process_group()
Using non-blocking operations can significantly improve performance by overlapping computation with communication, especially on systems with fast interconnects. However, it requires careful management of dependencies and synchronization points.
Understanding these torch.distributed primitives provides the foundation for implementing sophisticated distributed training workflows beyond the standard DDP pattern. They allow fine-grained control over inter-process communication, which is necessary for techniques like pipeline parallelism, custom gradient aggregation schemes, or interacting with specialized hardware communication libraries.