Distributing the computational load and model parameters across multiple devices is necessary for training large models, but it introduces a significant performance consideration: communication overhead. Every time data needs to be exchanged between devices, whether it's gradients, activations, or weight shards, time is spent on communication rather than computation. Minimizing this overhead is critical for achieving efficient scaling and reducing training time.
This section analyzes the communication patterns and costs associated with the different parallelism strategies we've discussed. Understanding these costs helps in choosing the most suitable strategy or combination of strategies for a given model architecture and hardware setup.
Distributed training relies on a small set of communication primitives: collective operations, in which a group of processes exchanges data in a coordinated way, plus point-to-point transfers. The most common primitives used in LLM training include:
- broadcast: Sends data from one process to all other processes.
- reduce: Combines data from all processes onto one process using a specified operation (e.g., sum, average).
- all_reduce: Combines data from all processes and distributes the result back to all processes. This is effectively a Reduce followed by a Broadcast. It is heavily used in Data Parallelism for synchronizing gradients.
- scatter: Distributes chunks of data from one process to all other processes.
- gather: Collects chunks of data from all processes onto one process.
- all_gather: Collects chunks of data from all processes and distributes the complete, concatenated data back to all processes. Used in some Tensor Parallelism implementations.
- reduce_scatter: Combines data from all processes using a reduction operation, then scatters the result so that each process receives one chunk of the final reduced tensor. Also used in Tensor Parallelism.
- send/recv: Direct point-to-point communication where one process sends data to another specific process, which receives it. This is the primary mechanism for Pipeline Parallelism.

In PyTorch, these operations are typically accessed via the torch.distributed package. For instance, performing an All-Reduce operation on a tensor t across all processes in a group might look like this:
import torch
import torch.distributed as dist
# Assume 't' is a tensor on the current device
# Assume distributed environment is initialized
# Launch an asynchronous All-Reduce (sum) and keep the returned work handle
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
# ... overlap independent computation here ...
# Synchronize before using the reduced values in 't'
work.wait()
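The other collectives follow the same calling pattern. Below is a minimal sketch of all_gather and reduce_scatter, assuming the process group is already initialized and using illustrative tensor sizes and names:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has been called and each rank has a GPU
world_size = dist.get_world_size()

# all_gather: every rank contributes its local shard and receives all shards
local_shard = torch.randn(1024, device="cuda")
gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
dist.all_gather(gathered, local_shard)       # gathered[i] holds rank i's shard

# reduce_scatter: element-wise sum across ranks; each rank keeps one reduced chunk
input_chunks = [torch.randn(1024, device="cuda") for _ in range(world_size)]
output_chunk = torch.empty(1024, device="cuda")
dist.reduce_scatter(output_chunk, input_chunks, op=dist.ReduceOp.SUM)
```

In sharded setups these two often appear as a complementary pair: reduce_scatter leaves each rank with its shard of a reduced tensor, and a later all_gather reconstructs the full tensor when it is needed.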
Point-to-point communication for Pipeline Parallelism involves pairs of sending and receiving processes:
# Process in stage 'i' sends activations to stage 'i+1'
if rank == i:
    # Assume 'activations' is the tensor to send
    dist.send(tensor=activations, dst=i + 1)
# Process in stage 'i+1' receives activations from stage 'i'
elif rank == i + 1:
    # Allocate a receive buffer with the expected shape (and matching dtype/device)
    received_activations = torch.empty(expected_activations_shape, device="cuda")
    dist.recv(tensor=received_activations, src=i)
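For overlapping communication with computation, torch.distributed also offers non-blocking point-to-point calls. The sketch below reuses the assumed names from the block above (rank, i, activations, expected_activations_shape); the overlap shown is illustrative:

```python
# Non-blocking variant using dist.isend / dist.irecv
if rank == i:
    req = dist.isend(tensor=activations, dst=i + 1)   # returns a request handle
    # ... do work that does not depend on the send completing ...
    req.wait()                                        # ensure the send has finished
elif rank == i + 1:
    received_activations = torch.empty(expected_activations_shape, device="cuda")
    req = dist.irecv(tensor=received_activations, src=i)
    # ... do work that does not need the incoming activations ...
    req.wait()                                        # block until the data has arrived
```

Pipeline implementations commonly use such non-blocking transfers to hide inter-stage communication behind computation on other micro-batches.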
The time taken for communication depends on several factors: the size of the message being transferred, the latency of the interconnect, the effective bandwidth between the devices, and the number of participating processes.
A common simple model for communication time T for a single message of size M is the alpha-beta model:
T ≈ α + M / β_eff

Here, α represents the latency component, M is the message size, and β_eff is the effective bandwidth achieved for the transfer. For collective operations involving P devices, the model becomes more complex, often involving terms logarithmic or linear in P, depending on the algorithm.
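To build intuition for how these terms interact, the alpha-beta model can be turned into a rough estimator. The sketch below uses the standard ring All-Reduce approximation, in which each of the P ranks moves roughly 2(P-1)/P of the message over 2(P-1) communication steps; the latency and bandwidth values are placeholder assumptions, not measurements of any particular system.

```python
def estimate_ring_allreduce_time(message_bytes: float,
                                 num_devices: int,
                                 alpha_s: float = 5e-6,           # assumed per-step latency (s)
                                 beta_bytes_per_s: float = 100e9  # assumed effective bandwidth
                                 ) -> float:
    """Rough alpha-beta estimate for a ring All-Reduce; illustrative only."""
    p = num_devices
    latency_term = 2 * (p - 1) * alpha_s
    bandwidth_term = (2 * (p - 1) / p) * message_bytes / beta_bytes_per_s
    return latency_term + bandwidth_term

# Example: summing ~14 GB of fp16 gradients (roughly a 7B-parameter model)
# across 8 GPUs with the assumed numbers above.
print(f"~{estimate_ring_allreduce_time(14e9, 8):.3f} s per gradient All-Reduce")
```

Even this crude estimate shows why the gradient All-Reduce in DP is dominated by bandwidth for large models, while the many small collectives in TP are more sensitive to the latency term.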
Let's analyze the communication costs inherent in each primary strategy:
- Data Parallelism (DP): Requires an all_reduce operation after the backward pass to sum gradients across all P replicas.
- Tensor Parallelism (TP): Requires all_reduce, all_gather, or reduce_scatter operations within the forward and backward passes of specific layers (e.g., MLP or Attention blocks) that are split across devices.

Figure: A simplified view of data flow in a 4-GPU Ring All-Reduce, often used in DP or TP collectives. Each GPU sends and receives chunks of data from its neighbor.

- Pipeline Parallelism (PP): Requires point-to-point send/recv operations between adjacent pipeline stages. Stage i sends its output activations to stage i+1 during the forward pass, and stage i+1 sends the gradients of those activations back to stage i during the backward pass.

| Strategy | Primary Operation(s) | Message Size | Frequency | Sensitivity | Bottleneck(s) |
|---|---|---|---|---|---|
| Data Parallelism | all_reduce | Model Gradients (Large) | Once per (accumulated) step | Bandwidth | All-Reduce time |
| Tensor Parallelism | all_reduce, all_gather, reduce_scatter | Layer Activations/Grads (Small/Medium) | Multiple per layer | Latency, Bandwidth | Frequent collective calls |
| Pipeline Parallelism | send/recv | Boundary Activations/Grads (Medium/Large) | Once per micro-batch / stage boundary | Latency, Bandwidth | Pipeline bubble, Inter-stage comms |
Table: Qualitative comparison of communication characteristics for different parallelism strategies.
Hybrid approaches combine these strategies, leading to more complex communication patterns. For example, combining DP with TP means that TP collectives run inside each model replica during the forward and backward passes, while the gradient All-Reduce runs across the data-parallel replicas. Combining TP and PP involves intra-stage TP collectives plus inter-stage send/recv traffic. One way to organize this in PyTorch is sketched below.
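A common way to express such a hybrid layout is to create a separate process group for each parallelism dimension, so that every collective involves only the ranks that need it. The world size, group sizes, and rank-to-group mapping below are illustrative assumptions; real frameworks (e.g., Megatron-LM, DeepSpeed) derive them from their configuration.

```python
import torch.distributed as dist

# Illustrative hybrid layout: 8 ranks total, tensor parallelism of size 2,
# data parallelism of size 4. Assumes dist.init_process_group() has run.
WORLD_SIZE = 8
TP_SIZE = 2

rank = dist.get_rank()

# TP groups: {0,1}, {2,3}, {4,5}, {6,7}.
# Every rank must create every group, in the same order.
tp_groups = [dist.new_group(list(range(start, start + TP_SIZE)))
             for start in range(0, WORLD_SIZE, TP_SIZE)]
my_tp_group = tp_groups[rank // TP_SIZE]

# DP groups: ranks that hold the same TP shard, i.e. {0,2,4,6} and {1,3,5,7}
dp_groups = [dist.new_group(list(range(offset, WORLD_SIZE, TP_SIZE)))
             for offset in range(TP_SIZE)]
my_dp_group = dp_groups[rank % TP_SIZE]

# TP collectives stay inside the (typically faster) intra-node group:
#   dist.all_reduce(partial_layer_output, group=my_tp_group)
# The DP gradient synchronization crosses replicas:
#   dist.all_reduce(gradients, group=my_dp_group)
```

Keeping the frequent, latency-sensitive TP collectives on the fastest interconnect (e.g., NVLink within a node) and reserving the slower inter-node links for the less frequent DP All-Reduce is a common placement heuristic.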
While theoretical analysis provides intuition, the actual communication overhead in a specific training run depends heavily on the implementation details, hardware, network configuration, and software stack (e.g., PyTorch version, NCCL version). Therefore, profiling is essential. Tools like torch.profiler, NVIDIA Nsight Systems (nsys), or framework-specific logging can help measure the time spent in different communication operations (nccl:all_reduce, nccl:send, etc.) versus computation kernels. Analyzing these profiles is critical for identifying bottlenecks and optimizing the distributed training configuration.
# Example using torch.profiler to capture CPU and GPU activity,
# including distributed communication calls (if using the NCCL backend)
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with torch.profiler.record_function("model_training_step"):
        # Your model forward, backward, and optimizer step here
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

# Print aggregated statistics
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Export trace for detailed analysis in tools like Perfetto UI
# or Chrome Trace Viewer
# prof.export_chrome_trace("trace.json")
By understanding the fundamental communication patterns and costs associated with each parallelism strategy, and by using profiling tools to measure real-world performance, you can make informed decisions about how to best distribute your LLM training workload to maximize efficiency and minimize training time.