While Data Parallelism replicates the model across devices and processes different data batches, it doesn't help when the model itself is too large to fit into the memory of a single accelerator (GPU). For such scenarios, we need Model Parallelism, which involves partitioning a single model across multiple devices. Tensor Model Parallelism (TMP) is a specific type of model parallelism where individual layers or even specific operations within layers are split across devices.
This approach is particularly relevant for models with massive parameter counts concentrated in specific layers, such as the large embedding tables or the feed-forward network (FFN) layers found in modern Transformer architectures.
The fundamental principle of Tensor Model Parallelism is to split the weight tensors (and sometimes activation tensors) of a layer across multiple GPUs in a coordinated manner. The computation is then performed partially on each GPU, followed by communication steps to synchronize or combine the results, ultimately yielding the same output as the original, unsplit layer.
Let's consider a standard linear layer, defined by the operation Y=XA+b, where X is the input activation, A is the weight matrix, b is the bias, and Y is the output activation. TMP offers different strategies to parallelize this operation.
In column parallelism, the weight matrix A is split column-wise across N GPUs. Let A = [A1, A2, ..., AN], where each Ai resides on a different GPU. Each GPU computes a partial output Yi = XAi + bi from the full input X, and the complete output is obtained by concatenating these partial outputs: Y = [Y1, Y2, ..., YN].
Column parallelism applied to a linear layer. Input X is shared, weights A and bias b are split column-wise, and partial outputs are concatenated.
This approach requires communication (a gather or all-gather operation) after the parallel computation to assemble the full output tensor Y.
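This arithmetic is easy to check without any GPUs. The following single-process sketch simulates column parallelism with torch.chunk; the tensor shapes and the world size of 2 are illustrative, not prescriptive.

import torch

torch.manual_seed(0)
N = 2                      # simulated tensor-parallel world size
X = torch.randn(4, 8)      # input activations: (batch, in_features)
A = torch.randn(8, 6)      # full weight matrix: (in_features, out_features)
b = torch.randn(6)         # full bias: (out_features,)

# Reference: the unsplit linear layer Y = XA + b
Y_ref = X @ A + b

# Column parallelism: split A (and b) along the output dimension
A_shards = torch.chunk(A, N, dim=1)   # each shard: (in_features, out_features / N)
b_shards = torch.chunk(b, N, dim=0)

# Each "GPU" computes a partial output from the full input X
Y_parts = [X @ A_i + b_i for A_i, b_i in zip(A_shards, b_shards)]

# Communication step: concatenate the partial outputs (an all-gather in practice)
Y = torch.cat(Y_parts, dim=1)

assert torch.allclose(Y, Y_ref, atol=1e-6)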
Alternatively, we can split the weight matrix A row-wise: A = [A1; A2; ...; AN], with the blocks Ai stacked vertically and each residing on a different GPU. The input X must then be split column-wise as X = [X1, X2, ..., XN], so that each GPU computes the partial product XiAi and the full output is Y = X1A1 + X2A2 + ... + XNAN + b.
Row parallelism applied to a linear layer. Input X is split, weights A are split row-wise, partial results are computed, and then summed via all-reduce. Bias b is added after the reduction.
Row parallelism requires communication (an all-reduce operation) after the matrix multiplication. A key benefit is that the output activation Y ends up replicated across all participating GPUs, which might be the required input format for a subsequent operation (e.g., a layer norm, a residual connection, or a following column-parallel linear layer, which expects a replicated input).
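The row-parallel arithmetic can be checked the same way. In this hedged single-process sketch, both the input X and the weight A are split with torch.chunk, the partial products are summed (standing in for the all-reduce), and the bias is added once after the reduction; the shapes and world size of 2 are again illustrative.

import torch

torch.manual_seed(0)
N = 2
X = torch.randn(4, 8)      # input activations: (batch, in_features)
A = torch.randn(8, 6)      # full weight matrix: (in_features, out_features)
b = torch.randn(6)         # full bias: (out_features,)

# Reference: the unsplit linear layer Y = XA + b
Y_ref = X @ A + b

# Row parallelism: split A along the input dimension and X along its columns
A_shards = torch.chunk(A, N, dim=0)   # each shard: (in_features / N, out_features)
X_shards = torch.chunk(X, N, dim=1)   # each shard: (batch, in_features / N)

# Each "GPU" computes a partial sum of the full output
Y_parts = [X_i @ A_i for X_i, A_i in zip(X_shards, A_shards)]

# Communication step: sum the partial results (an all-reduce in practice),
# then add the bias once after the reduction
Y = torch.stack(Y_parts).sum(dim=0) + b

assert torch.allclose(Y, Y_ref, atol=1e-5)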
Transformer models often utilize a combination of these techniques. For instance, within a standard Feed-Forward Network (FFN) block consisting of two linear layers, the first layer (which expands the hidden dimension, e.g., from h to 4h) is made column-parallel, so each GPU holds only a shard of the intermediate activations and applies the elementwise nonlinearity (e.g., GELU) locally without any communication. The second layer (which projects back from 4h to h) is made row-parallel, consuming those sharded activations directly and producing the full output with a single all-reduce at the end of the block.
This strategic combination helps minimize communication overhead by keeping activations sharded between the two linear layers within the FFN block.
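This pairing can be verified numerically on a single process. In the sketch below (biases omitted for brevity), the first weight matrix is split column-wise, GELU is applied to each local shard, the second weight matrix is split row-wise, and a final sum stands in for the all-reduce; the hidden size and world size are illustrative.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 2
h = 8                       # hidden size
X = torch.randn(4, h)       # input activations: (batch, h)
A = torch.randn(h, 4 * h)   # first FFN weight: (h, 4h)
B = torch.randn(4 * h, h)   # second FFN weight: (4h, h)

# Reference: the unsplit FFN, Y = GELU(XA) B
Y_ref = F.gelu(X @ A) @ B

# First linear layer is column-parallel: split A along its output dimension
A_shards = torch.chunk(A, N, dim=1)   # each: (h, 4h / N)
# Second linear layer is row-parallel: split B along its input dimension
B_shards = torch.chunk(B, N, dim=0)   # each: (4h / N, h)

# Each "GPU" keeps its intermediate activation shard local: no communication
# is needed between the two linear layers, because GELU is elementwise
Y_parts = [F.gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]

# A single sum (an all-reduce in practice) at the end of the block restores Y
Y = torch.stack(Y_parts).sum(dim=0)

assert torch.allclose(Y, Y_ref, atol=1e-5)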
For models with extremely large vocabulary sizes, the embedding table can become a significant memory bottleneck. Tensor Model Parallelism can be applied here by splitting the embedding table row-wise (along the vocabulary dimension) across the GPUs. When looking up embeddings for input token IDs:
- Each GPU holds only its slice of the vocabulary and performs the lookup for the token IDs that fall within that slice.
- Token IDs that fall outside a GPU's slice contribute a zero vector from that GPU (they are masked out locally).
- An all-reduce (sum) across the tensor parallel group then leaves every GPU with the complete embedding for every token in the batch.
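A single-process sketch of this lookup pattern is shown below. The vocabulary size, partitioning, and masking logic are illustrative rather than a specific library's implementation; on real hardware, the final sum would be a torch.distributed.all_reduce across the tensor parallel group.

import torch

torch.manual_seed(0)
N = 2                        # simulated tensor-parallel world size
vocab_size, dim = 8, 4
full_table = torch.randn(vocab_size, dim)   # full embedding table
token_ids = torch.tensor([0, 5, 3, 7])      # a batch of input token IDs

# Reference: an ordinary lookup into the full table
ref = full_table[token_ids]

# Split the table row-wise: each "GPU" owns a contiguous slice of the vocabulary
rows_per_rank = vocab_size // N
partial_results = []
for rank in range(N):
    start, end = rank * rows_per_rank, (rank + 1) * rows_per_rank
    local_table = full_table[start:end]          # this rank's shard of the table

    # Map global token IDs to local row indices and flag out-of-range tokens
    in_range = (token_ids >= start) & (token_ids < end)
    local_ids = (token_ids - start).clamp(0, rows_per_rank - 1)

    # Out-of-range tokens contribute zero vectors from this rank
    local_out = local_table[local_ids] * in_range.unsqueeze(-1).float()
    partial_results.append(local_out)

# Communication step: summing the partial results (an all-reduce in practice)
# leaves every rank with the correct embedding for every token
out = torch.stack(partial_results).sum(dim=0)
assert torch.allclose(out, ref)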
Implementing Tensor Model Parallelism manually requires careful handling of tensor sharding, computation, and synchronization using primitives from the torch.distributed package:
- torch.distributed.broadcast: Sends a tensor from one rank to all others.
- torch.distributed.scatter: Scatters chunks of a tensor across ranks.
- torch.distributed.gather: Gathers tensors from all ranks onto one rank.
- torch.distributed.all_gather: Gathers tensors from all ranks onto all ranks.
- torch.distributed.reduce_scatter: Performs an operation (like sum) across ranks and scatters the result.
- torch.distributed.all_reduce: Performs an operation (like sum) across ranks and makes the result available on all ranks.
Specialized functions or wrappers are often created to encapsulate these operations for specific layer types (e.g., ColumnParallelLinear, RowParallelLinear). Libraries like NVIDIA's Megatron-LM pioneered many of these techniques, and parts of this functionality are being integrated into PyTorch core via modules like torch.distributed.tensor.parallel, which provide higher-level APIs to simplify these implementations.
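To make the listing that follows more concrete, here is one way the two hypothetical communication helpers it relies on could be sketched from these primitives. These are forward-only simplifications for a column-parallel layer; production implementations (e.g., Megatron-LM, or the utilities in torch.distributed.tensor.parallel) wrap the collectives in torch.autograd.Function subclasses so that gradients are routed correctly in the backward pass.

import torch
import torch.distributed as dist

def copy_to_tensor_model_parallel_region(input_, group=None):
    # Forward direction: the input is already replicated on every rank in the
    # tensor parallel group, so this is an identity. A full implementation
    # would all-reduce the gradient in the backward pass.
    return input_

def gather_from_tensor_model_parallel_region(input_, group=None):
    # Forward direction: all-gather the per-rank output shards and concatenate
    # them along the last (column) dimension. A full implementation would also
    # define the backward pass, which slices the gradient back into shards.
    world_size = dist.get_world_size(group=group)
    shards = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(shards, input_.contiguous(), group=group)
    return torch.cat(shards, dim=-1)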
Consider a conceptual example for a column-parallel linear layer using hypothetical helper functions:
import torch
import torch.nn as nn
import torch.distributed as dist

# Assume these helper functions exist for managing parallel groups and communication
from .parallel_utils import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    copy_to_tensor_model_parallel_region,     # Handles input broadcasting/splitting
    gather_from_tensor_model_parallel_region  # Handles output gathering/reduction
)

class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, bias=True, **kwargs):
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()
        # Ensure output_size is divisible by world_size
        assert output_size % world_size == 0
        self.output_size_per_partition = output_size // world_size
        self.input_size = input_size

        # Weight matrix is split along the output dimension (columns)
        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partition, self.input_size, **kwargs
        ))
        # Initialize weights... (e.g., using init.kaiming_uniform_)

        if bias:
            # Bias is also split along the output dimension
            self.bias = nn.Parameter(torch.empty(
                self.output_size_per_partition, **kwargs
            ))
            # Initialize bias... (e.g., using init.zeros_)
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Input might need to be broadcast or is already available if
        # the previous layer's output was replicated (e.g., LayerNorm).
        # This function handles the necessary communication.
        parallel_input = copy_to_tensor_model_parallel_region(input_)

        # Perform local matrix multiplication
        output_parallel = nn.functional.linear(parallel_input, self.weight, self.bias)

        # Gather results from all GPUs in the tensor parallel group.
        # Concatenates along the column dimension.
        output_ = gather_from_tensor_model_parallel_region(output_parallel)
        return output_
Implementation sketch of a Column Parallel Linear layer. Note the explicit sharding of weights/bias and the use of hypothetical communication wrappers.
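For symmetry, a row-parallel counterpart might be sketched as follows. It reuses the imports from the listing above and assumes two additional hypothetical helpers, scatter_to_tensor_model_parallel_region (splits a replicated input column-wise across ranks) and reduce_from_tensor_model_parallel_region (all-reduces the partial outputs); neither is a real PyTorch API.

class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=True, **kwargs):
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()
        # Ensure input_size is divisible by world_size
        assert input_size % world_size == 0
        self.input_size_per_partition = input_size // world_size
        self.input_is_parallel = input_is_parallel

        # Weight matrix is split along the input dimension (rows)
        self.weight = nn.Parameter(torch.empty(
            output_size, self.input_size_per_partition, **kwargs
        ))
        # Initialize weights...

        if bias:
            # The full bias is kept on every rank and added once, after the reduction
            self.bias = nn.Parameter(torch.empty(output_size, **kwargs))
            # Initialize bias...
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # If the previous layer was column-parallel and skipped its gather,
        # the input is already sharded; otherwise it must be split here.
        if self.input_is_parallel:
            parallel_input = input_
        else:
            parallel_input = scatter_to_tensor_model_parallel_region(input_)

        # Local partial matmul; the bias must not be added per rank,
        # or it would be counted world_size times after the reduction.
        output_parallel = nn.functional.linear(parallel_input, self.weight)

        # Sum the partial results across the tensor parallel group (all-reduce)
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if self.bias is not None:
            output_ = output_ + self.bias
        return output_

Implementation sketch of a Row Parallel Linear layer, mirroring the column-parallel version above; the scatter and reduce wrappers are hypothetical.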
In summary, Tensor Model Parallelism is an indispensable technique when individual model layers exceed single-device memory limits. By splitting weights and computation within layers across multiple devices, it enables the training of models far larger than would otherwise be possible, albeit at the cost of increased implementation complexity and communication overhead. It's a fundamental building block for scaling today's largest neural networks.