While Data Parallelism (DP) effectively utilizes multiple devices by replicating the model and processing different data slices, it requires each device to hold the entire model. For truly massive models, even a single replica might exceed the memory capacity of one accelerator. Furthermore, some individual operations within the model, like the large linear transformations in feed-forward networks or attention mechanisms, can become computational bottlenecks. This is where Tensor Parallelism (TP), sometimes called intra-layer model parallelism, becomes necessary.
Instead of splitting the data or the sequence of layers, Tensor Parallelism splits the actual tensors (weights, activations, gradients) within a specific operation across multiple devices. This allows computations on these large tensors to be performed in parallel, distributing both the memory footprint and the computational load of individual layers.
The most straightforward application of TP is in large linear layers, which are fundamental components of both the multi-layer perceptron (MLP) and attention blocks in Transformers. Consider a linear transformation Y=XA, where X is the input activation and A is the weight matrix. We can parallelize this matrix multiplication across, for example, two devices in two primary ways: column parallelism and row parallelism.
In column parallelism, the weight matrix A is split vertically (along its columns) across devices. If we have two devices, we split A into A = [A1, A2]. The input X is typically broadcast or otherwise made available to both devices. Each device then computes a part of the output: Y1 = XA1 on the first device and Y2 = XA2 on the second.
The final output Y is obtained by concatenating the results along the column dimension: Y = [Y1, Y2].
Column parallelism for Y = XA. The weight matrix A is split into A1 and A2. The input X is multiplied by each part on separate GPUs. The results Y1 and Y2 are concatenated to form the final output Y.
Mathematically, the forward pass involves XA = X[A1, A2] = [XA1, XA2]. The backward pass requires gradients with respect to X and A. The gradient ∂L/∂A is naturally partitioned into ∂L/∂A1 and ∂L/∂A2. However, computing the gradient ∂L/∂X requires summing contributions from both paths: ∂L/∂X = (∂L/∂Y1)A1^T + (∂L/∂Y2)A2^T. This summation is typically achieved using an all-reduce communication collective across the devices holding A1 and A2.
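To make the column split concrete, here is a small single-process sketch (the two "devices" are simulated by ordinary tensors, and the shapes are purely illustrative) verifying that concatenating the partial products reproduces the full product:

import torch

# Illustrative shapes: 4 tokens, input size 8, output size 6
X = torch.randn(4, 8)          # input activations, replicated on both "devices"
A = torch.randn(8, 6)          # full weight matrix

# Column parallelism: split A along its columns into two shards
A1, A2 = A[:, :3], A[:, 3:]

# Each "device" computes its slice of the output independently
Y1 = X @ A1                    # shape (4, 3), held on device 0
Y2 = X @ A2                    # shape (4, 3), held on device 1

# Concatenating the partial outputs recovers Y = XA
Y = torch.cat([Y1, Y2], dim=-1)
assert torch.allclose(Y, X @ A, atol=1e-6)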
In row parallelism, the weight matrix A is split horizontally (along its rows). For two devices, A is stacked as A = [A1; A2]. The input X is correspondingly partitioned along its columns (often because it is the output of a preceding column-parallel layer); for this isolated view, assume the full X is available on both devices and each takes the slice it needs. Each device computes a partial result based on its slice of the weights: Y1 = X1A1 and Y2 = X2A2, where X = [X1, X2].
Unlike column parallelism, the final output Y is the sum of the partial results: Y = Y1 + Y2.
Row parallelism for Y = XA. The weight matrix A is split row-wise into A1 and A2. Partial results Y1 and Y2 are computed on separate GPUs. An all-reduce operation sums these results to produce the final output Y.
Mathematically, writing Y = X[A1; A2] is not meaningful if X is treated as a single block; the dimensions only work out when X is partitioned to match. The actual operation in context (e.g., following a column-split layer) is Y = X1A1 + X2A2, where X = [X1, X2]. The forward pass requires an all-reduce operation to perform this summation across devices. In the backward pass, the gradients with respect to the partitioned input follow the partitioning of A: ∂L/∂X1 = (∂L/∂Y)A1^T and ∂L/∂X2 = (∂L/∂Y)A2^T. The gradient ∂L/∂A is computed directly on each device for its corresponding weight slice, with no additional communication.
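The same kind of single-process sketch illustrates the row-parallel case: with X split along its columns to match the row-split A, the per-device partial products sum to the full product, which is exactly the summation an all-reduce performs across real devices.

import torch

X = torch.randn(4, 8)          # input activations
A = torch.randn(8, 6)          # full weight matrix

# Row parallelism: split A along its rows, and X along its columns to match
A1, A2 = A[:4, :], A[4:, :]
X1, X2 = X[:, :4], X[:, 4:]

# Each "device" produces a full-width but partial output
Y1 = X1 @ A1                   # shape (4, 6), held on device 0
Y2 = X2 @ A2                   # shape (4, 6), held on device 1

# Summing the partials (the job of the all-reduce) recovers Y = XA
Y = Y1 + Y2
assert torch.allclose(Y, X @ A, atol=1e-6)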
Tensor Parallelism is typically applied strategically within Transformer blocks to balance computation and minimize communication overhead. A common pattern, popularized by frameworks like Megatron-LM, combines column and row parallelism within the MLP block and applies similar partitioning to the attention mechanism.
MLP Block: A standard Transformer MLP block computes Y = Activation(XA)B + X (including the residual connection). TP is applied as follows: the first linear layer (weight A) is column-parallel, so each device produces a slice of the intermediate activations and can apply the nonlinearity locally without communication; the second linear layer (weight B) is row-parallel, so each device produces a partial output of the full hidden size, and a single all-reduce sums these partials before the residual connection is added.
This combination arranges the parallelism so that the forward pass needs only one all-reduce per MLP block (after the row-parallel second layer) and the backward pass needs only one (for the input gradient of the column-parallel first layer), keeping synchronization to a minimum.
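As a rough sketch of that pattern (assuming torch.distributed is already initialized and each rank holds its own shards a_shard and b_shard of the two weight matrices; the function and argument names here are made up for illustration), the forward pass needs only a single all-reduce:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def parallel_mlp_forward(x, a_shard, b_shard):
    # a_shard: this rank's column slice of A, shape (hidden, ffn_hidden // world_size)
    # b_shard: this rank's row slice of B, shape (ffn_hidden // world_size, hidden)

    # Column-parallel first layer: each rank gets a slice of the intermediate
    # activations, so the nonlinearity can be applied locally without communication.
    h = F.gelu(x @ a_shard)

    # Row-parallel second layer: each rank produces a partial output of the
    # full hidden size.
    partial = h @ b_shard

    # One all-reduce sums the partial outputs across the TP group; the result
    # matches the unsharded computation.
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)

    # Residual connection, as in Y = Activation(XA)B + X
    return x + partial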
Attention Block: TP can also be applied to the self-attention mechanism. The query, key, and value projections are split column-wise so that each device holds a subset of the attention heads and computes attention for those heads entirely locally; the output projection is split row-wise, and a single all-reduce sums the partial outputs across the TP group, mirroring the MLP pattern.
The implementation details within the attention block can be intricate, involving optimized kernels and communication patterns to handle the sequence length and head dimensions efficiently.
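A minimal sketch of the head-partitioning idea follows (a simplification that ignores masking, dropout, and fused kernels, and assumes the number of heads divides evenly across the TP group; the function and argument names are illustrative):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def parallel_self_attention(x, wq_shard, wk_shard, wv_shard, wo_shard,
                            heads_per_rank, head_dim):
    # wq/wk/wv_shard: column slices of the Q/K/V projection weights,
    #     shape (hidden, heads_per_rank * head_dim)
    # wo_shard: row slice of the output projection,
    #     shape (heads_per_rank * head_dim, hidden)
    batch, seq, _ = x.shape

    # Column-parallel Q/K/V projections: only this rank's heads are computed.
    q = (x @ wq_shard).view(batch, seq, heads_per_rank, head_dim).transpose(1, 2)
    k = (x @ wk_shard).view(batch, seq, heads_per_rank, head_dim).transpose(1, 2)
    v = (x @ wv_shard).view(batch, seq, heads_per_rank, head_dim).transpose(1, 2)

    # Attention over the local heads only; no cross-device communication needed.
    attn = F.scaled_dot_product_attention(q, k, v)
    attn = attn.transpose(1, 2).reshape(batch, seq, heads_per_rank * head_dim)

    # Row-parallel output projection yields a partial result that a single
    # all-reduce sums across the TP group.
    partial = attn @ wo_shard
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    return partial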
A significant drawback of Tensor Parallelism is the increased communication overhead compared to Data Parallelism. While DP typically involves one all-reduce per training step (for gradients), TP introduces communication within the forward and backward passes of each Transformer block.
These communication operations (like all-reduce) involve synchronizing and exchanging data across all devices participating in the TP group. The volume of data exchanged depends on the size of the activations or gradients being communicated. This communication cost scales with the number of devices in the TP group and can become a performance bottleneck if the interconnect bandwidth between devices (e.g., NVLink, InfiniBand) is insufficient relative to the computation speed.
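For a rough sense of the volume involved, the per-GPU traffic of one all-reduce can be estimated from the activation size; the shapes below are illustrative, and the 2(p - 1)/p factor assumes a standard ring all-reduce:

# Back-of-the-envelope estimate of TP all-reduce traffic (illustrative numbers)
batch, seq_len, hidden = 8, 2048, 8192
bytes_per_elem = 2                 # fp16 / bf16 activations
tp_size = 8                        # GPUs in the TP group

activation_bytes = batch * seq_len * hidden * bytes_per_elem
# A ring all-reduce moves roughly 2 * (p - 1) / p times the tensor size per GPU
per_gpu_traffic = 2 * (tp_size - 1) / tp_size * activation_bytes

print(f"Activation tensor: {activation_bytes / 1e6:.0f} MB")
print(f"Approx. per-GPU traffic per all-reduce: {per_gpu_traffic / 1e6:.0f} MB")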
However, consider these points: TP groups are usually kept small and confined to a single node, where high-bandwidth interconnects such as NVLink make the frequent all-reduces affordable; sharding the weights and activations of individual layers is often the only way to fit them into accelerator memory at all; and TP composes with Data and Pipeline Parallelism, so the expensive intra-layer communication is confined to a small subset of the total devices.
Here's a PyTorch-style snippet illustrating the idea of splitting a weight matrix for column parallelism (this is highly simplified and omits communication and gradient handling):
import torch
import torch.nn as nn

# Assume world_size is the number of GPUs in the TP group
# Assume rank is the current GPU's rank within that group

class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()
        self.input_size = input_size
        # Each GPU handles a slice of the output features
        self.output_size_per_partition = output_size // world_size
        self.world_size = world_size
        self.rank = rank
        # Initialize only this rank's partition of the weight matrix.
        # Note: PyTorch linear layers store weight as (out_features, in_features).
        self.weight = nn.Parameter(
            torch.randn(self.output_size_per_partition, self.input_size)
        )
        # The bias is partitioned the same way
        self.bias = nn.Parameter(
            torch.randn(self.output_size_per_partition)
        )

    def forward(self, x):
        # The input x is assumed to be available on all GPUs (broadcast or
        # the result of a preceding all-reduce).
        # The matrix multiplication uses only the local weight partition:
        # output_partition = x @ A_partition^T + b_partition
        output_partition = nn.functional.linear(x, self.weight, self.bias)
        # In a real implementation, output_partition would need to be
        # gathered across GPUs if the next layer is not row-parallel.
        # For the Megatron-LM MLP pattern (column -> row), no forward
        # communication is needed here.
        # Backward-pass communication (the all-reduce for grad_x) is handled
        # by custom autograd functions in libraries like Megatron-LM.
        return output_partition  # Only the slice computed by this GPU
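Continuing from the snippet above, a quick single-process sanity check (simulating two TP ranks with two module instances and copying slices of a reference nn.Linear into them) shows that concatenating the per-rank outputs reproduces the unsharded layer:

# Single-process sanity check: two instances stand in for two TP ranks
torch.manual_seed(0)
full = nn.Linear(16, 8)            # reference unsharded layer
shards = [ColumnParallelLinear(16, 8, world_size=2, rank=r) for r in range(2)]

# Copy the matching slice of the reference weights into each shard
with torch.no_grad():
    for r, shard in enumerate(shards):
        shard.weight.copy_(full.weight[r * 4:(r + 1) * 4, :])
        shard.bias.copy_(full.bias[r * 4:(r + 1) * 4])

x = torch.randn(3, 16)
# Gathering (concatenating) the per-rank outputs matches the full layer
gathered = torch.cat([shard(x) for shard in shards], dim=-1)
assert torch.allclose(gathered, full(x), atol=1e-6)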
In summary, Tensor Parallelism is a powerful technique for distributing the computation and memory of individual large layers across multiple devices. While it introduces significant communication overhead, it is often a necessary component, alongside Data and Pipeline Parallelism, for training state-of-the-art large language models that push the boundaries of single-accelerator capabilities. Frameworks like Megatron-LM and DeepSpeed abstract away much of the complexity, providing optimized building blocks for implementing TP effectively.