When a neural network becomes too large to fit on a single device, even though its individual layers still do (a scenario not addressed by DistributedDataParallel, which keeps a full model replica on every device), or when we want to overlap computation and communication differently, Pipeline Parallelism offers an alternative scaling strategy. Instead of replicating the model (like DDP) or splitting individual layers (like Tensor Parallelism), Pipeline Parallelism partitions the model itself sequentially across multiple devices. Each device, or group of devices, becomes a "stage" in the pipeline, responsible for executing a subset of the model's layers.
Imagine a model composed of several sequential layers or blocks. In pipeline parallelism, you assign consecutive blocks to different devices. For example, in a four-layer model running on four GPUs:
Input data enters the first stage (GPU 0), which runs Layer 1. After processing, the output activation is sent to the second stage (GPU 1), and this continues until the final stage (GPU 3, holding Layer 4) computes the output and the loss. The gradients then flow backward through the pipeline in reverse order: GPU 3 computes gradients for Layer 4 and sends the gradient of Layer 3's output back to GPU 2, which computes gradients for Layer 3 and sends the gradient of Layer 2's output back to GPU 1, and so on, until the gradients reach the first stage.
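In code, this assignment is simple to express: wrap consecutive layers in nn.Sequential blocks and move each block to its own device. The sketch below assumes four CUDA devices in a single process and uses placeholder layer sizes; the single-process loop only illustrates the data flow, while the multi-process version with explicit communication appears later in this section.

import torch
import torch.nn as nn

# One block per pipeline stage, each moved to its own GPU (illustrative sizes)
stages = [
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0"),  # Layer 1 -> stage 0
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:1"),  # Layer 2 -> stage 1
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:2"),  # Layer 3 -> stage 2
    nn.Sequential(nn.Linear(1024, 10)).to("cuda:3"),               # Layer 4 -> stage 3
]

# Naive flow: each stage runs only after the previous one has finished
x = torch.randn(32, 1024, device="cuda:0")
for i, stage in enumerate(stages):
    x = stage(x.to(f"cuda:{i}"))
# x now holds the final output on cuda:3; the loss and backward pass would follow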
A naive implementation of this process is inefficient. Consider the timeline for a single batch flowing through: while Stage 1 processes the batch, Stage 0 sits idle; while Stage 2 processes it, Stages 0 and 1 are both idle. During the backward pass, stages go idle again as they wait for gradients from the subsequent stage. This idle time, known as the "pipeline bubble," significantly reduces hardware utilization.
A naive pipeline execution with a single batch, showing significant idle periods (bubbles) on GPUs during forward (Fwd) and backward (Bwd) passes across time steps (T1-T8).
The standard solution to the pipeline bubble problem is micro-batching. Instead of feeding the entire mini-batch through the pipeline at once, we split it into smaller chunks called micro-batches. The pipeline processes these micro-batches concurrently.
As soon as Stage 0 finishes processing the first micro-batch and sends its activations to Stage 1, Stage 0 can immediately start processing the second micro-batch. This allows multiple micro-batches to be "in flight" within the pipeline simultaneously, overlapping computation across stages and significantly reducing idle time. The number of micro-batches (m) is a hyperparameter; a larger m generally leads to better utilization but increases communication overhead and potentially memory usage due to storing intermediate activations and gradients for each micro-batch.
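As a rough rule of thumb for the GPipe-style schedule sketched below (ignoring communication time), with p pipeline stages and m micro-batches the fraction of the schedule lost to the bubble is approximately (p − 1) / (m + p − 1). For example, with p = 4 stages and m = 8 micro-batches, roughly 3/11 ≈ 27% of the time is idle, and increasing m shrinks this fraction.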
Pipeline execution with micro-batching (m0-m3). Forward (F) and backward (B) passes for different micro-batches overlap across stages (GPUs), reducing idle time compared to the naive approach. The initial fill and final drain phases still have some bubbles.
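In code, splitting a mini-batch into micro-batches is a one-liner with torch.chunk. A minimal sketch (the batch size, feature size, and m = 4 are illustrative) that produces the micro_batches_data and micro_batches_labels lists used in the two-stage example further below:

import torch

mini_batch = torch.randn(64, 1024)          # full mini-batch (illustrative shape)
labels = torch.randint(0, 10, (64,))        # matching labels
num_micro_batches = 4                       # the hyperparameter m
micro_batches_data = list(torch.chunk(mini_batch, num_micro_batches, dim=0))
micro_batches_labels = list(torch.chunk(labels, num_micro_batches, dim=0))
# Each micro-batch now holds 16 samples and flows through the pipeline independently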
PyTorch offers tools to implement pipeline parallelism, although it often requires more manual setup than DDP. The core components involve:

- Dividing the nn.Module into sequential nn.Sequential blocks, one for each stage, and placing each block onto its designated device.
- Using torch.distributed.send and torch.distributed.recv operations to transfer activations forward and gradients backward between adjacent stages. Remember that these operations are blocking.

Let's outline the conceptual flow for a simple two-stage pipeline (GPU 0 and GPU 1) with micro-batching:
import torch
import torch.nn as nn
import torch.distributed as dist

# Assume the distributed environment is initialized with the Gloo backend
# (rank 0 drives GPU 0, rank 1 drives GPU 1, world_size = 2) and that the
# model has been split into stage0 and stage1, each placed on its own device.
# `activation_shape` is the shape of the tensor passed between the stages;
# both ranks must agree on it in advance.

def run_pipeline_step(stage0, stage1, micro_batches_data, micro_batches_labels,
                      loss_fn, optimizer, activation_shape):
    num_micro_batches = len(micro_batches_data)
    activations_storage = [None] * num_micro_batches  # Activations kept for the backward pass
    losses_storage = [None] * num_micro_batches       # Per-micro-batch losses (last stage only)

    current_rank = dist.get_rank()

    # --- Forward Pass ---
    for i in range(num_micro_batches):
        if current_rank == 0:  # First stage
            # Compute activations for stage 0
            micro_batch = micro_batches_data[i].to(current_rank)
            activations = stage0(micro_batch)
            # Send activations to the next stage (rank 1). Gloo communicates
            # CPU tensors; with NCCL you would send the GPU tensor directly.
            dist.send(activations.detach().cpu(), dst=1, tag=i)
            activations_storage[i] = activations  # Store the graph output for backward
        elif current_rank == 1:  # Last stage
            # Receive activations from the previous stage (rank 0) into a
            # correctly shaped buffer
            received_activations = torch.empty(activation_shape)
            dist.recv(received_activations, src=0, tag=i)
            received_activations = received_activations.to(current_rank)
            received_activations.requires_grad_()  # IMPORTANT: enable grad so its gradient can be sent back
            # Compute activations for stage 1 (final output) and the loss
            outputs = stage1(received_activations)
            labels = micro_batches_labels[i].to(current_rank)
            loss = loss_fn(outputs, labels)
            # Store what the backward pass needs
            activations_storage[i] = received_activations  # Input to this stage
            losses_storage[i] = loss                       # Used to initiate backward later

    # --- Backward Pass ---
    # Iterate backward through the micro-batches (GPipe-style schedule)
    for i in range(num_micro_batches - 1, -1, -1):
        if current_rank == 1:  # Last stage starts each backward pass
            loss = losses_storage[i]
            input_activation = activations_storage[i]
            # Accumulates gradients into stage1's parameters and into
            # input_activation.grad; each micro-batch has its own graph,
            # so no retain_graph is needed.
            loss.backward()
            # Send the gradient of this stage's input back to the previous stage (rank 0)
            dist.send(input_activation.grad.cpu(), dst=0, tag=i)
        elif current_rank == 0:  # First stage finishes each backward pass
            # Receive the gradient from the next stage (rank 1)
            grad_received = torch.empty(activation_shape)
            dist.recv(grad_received, src=1, tag=i)
            grad_received = grad_received.to(current_rank)
            # Continue the backward pass from stage0's output; this accumulates
            # gradients into stage0's parameters.
            output_activation = activations_storage[i]
            output_activation.backward(gradient=grad_received)

    # --- Optimizer Step ---
    # After processing all micro-batches, update this stage's weights
    optimizer.step()
    optimizer.zero_grad()

# --- Notes ---
# 1. This is a simplified conceptual example (GPipe-style schedule).
# 2. A real implementation needs a mechanism to agree on tensor shapes for the
#    recv buffers; here `activation_shape` is assumed to be known by both ranks.
# 3. Error handling, proper device placement, and synchronization are crucial.
# 4. More advanced schedules (e.g., interleaved) exist.
# 5. Libraries like `torch.distributed.pipeline` (experimental) aim to simplify this.
This manual implementation highlights the complexity involved: explicit communication calls, managing intermediate activations and their gradients for each micro-batch, and careful synchronization between stages.
PyTorch also has an experimental torch.distributed.pipeline.sync.Pipe module designed to abstract some of this complexity, automatically handling micro-batching, communication, and gradient propagation based on a model definition split into stages. However, understanding the manual process using primitives provides valuable insight into the underlying operations.
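For PyTorch versions where this experimental module is available, usage looks roughly like the sketch below. It assumes two GPUs in a single process, requires the RPC framework to be initialized, and uses placeholder layer sizes.

import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe relies on the RPC framework, even for a single-process setup
rpc.init_rpc("worker", rank=0, world_size=1)

# Two stages, each already placed on its own GPU (illustrative sizes)
stage0 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()).to("cuda:0")
stage1 = nn.Sequential(nn.Linear(1024, 10)).to("cuda:1")

model = Pipe(nn.Sequential(stage0, stage1), chunks=4)  # chunks = number of micro-batches

x = torch.randn(64, 1024, device="cuda:0")
output = model(x).local_value()  # forward returns an RRef; backward and the optimizer step proceed as usual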
Pipeline parallelism is most beneficial for extremely large models where even tensor parallelism isn't sufficient, or when fine-grained control over device execution and memory is required. It's often combined with data parallelism (e.g., replicating each pipeline stage and keeping the replicas synchronized with DDP) for further scaling.