As discussed in the chapter introduction, the computational bottleneck of the standard Transformer lies in the self-attention mechanism. Calculating attention weights requires comparing every token (query) with every other token (key) in the sequence. This results in a computation and memory complexity of O(N²⋅d), where N is the sequence length and d is the model dimension. While often simplified to O(N²) when d is treated as a constant, this quadratic scaling makes processing very long sequences (thousands or tens of thousands of tokens) computationally prohibitive.
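To make the scaling concrete, the short sketch below simply counts query-key pairs and the resulting multiply-adds for a few sequence lengths. The model dimension d = 768 is an illustrative choice, and the tally ignores constant factors.

```python
# Rough tally of query-key score computations for full self-attention.
# Each score is a d-dimensional dot product, so cost grows as N^2 * d.
d = 768  # illustrative model dimension

for n in (1_024, 8_192, 65_536):
    pairs = n * n              # number of query-key comparisons
    mult_adds = pairs * d      # dot-product work, ignoring constants
    print(f"N={n:>6}: {pairs:.3e} pairs, ~{mult_adds:.3e} mult-adds per head")
```

Going from 8K to 64K tokens multiplies the work by 64, which is why quadratic attention quickly becomes the dominant cost.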
Sparse attention mechanisms aim to alleviate this bottleneck by reducing the number of query-key pairs that need to be computed. Instead of allowing each token to attend to all other tokens, sparse attention restricts attention to a carefully chosen subset of tokens, forming a sparse attention matrix. The core assumption is that full attention is often redundant, and most of the relevant information for a given token can be captured by attending to a smaller, strategically selected set of other tokens.
Empirical analysis often reveals that learned attention matrices in well-trained Transformers are indeed sparse. Many attention weights are close to zero, indicating that a token only strongly attends to a limited number of other tokens. Sparse attention methods try to exploit this observation by predefining or learning patterns that focus computation on potentially important relationships, drastically reducing the N² comparisons.
Several structured sparsity patterns have been proposed, often combining different types of attention to balance local context with broader, long-range interactions.
The most intuitive pattern is local or sliding window attention. Here, each token only attends to a fixed-size window of k neighboring tokens (e.g., k/2 tokens to the left and k/2 tokens to the right).
Figure: local attention pattern where token 4 attends to itself and its neighbors within a window (k=4 here: 2 left, 2 right, plus self). Computation is restricted to these neighbors.
This reduces the complexity per token from O(N) to O(k), leading to an overall complexity of O(N⋅k). Since k is a fixed hyperparameter typically much smaller than N, this is effectively linear in N. However, the obvious limitation is that information cannot propagate beyond the window size k within a single attention layer. Capturing long-range dependencies requires stacking many such layers.
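As a minimal sketch of this pattern, the snippet below builds a banded boolean mask in PyTorch. The function name `sliding_window_mask` and its `window` argument are illustrative, not a standard API.

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask: entry (i, j) is True when query i may attend to key j.

    Each token sees itself plus window // 2 neighbors on each side
    (illustrative hyperparameter, not a fixed standard).
    """
    idx = torch.arange(seq_len)
    return (idx[None, :] - idx[:, None]).abs() <= window // 2  # banded pattern

print(sliding_window_mask(seq_len=8, window=4).int())  # each row has at most 5 ones
```

Only the True entries need to be computed, which is what brings the per-token cost down to O(k).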
To mitigate the limited receptive field of local attention without sacrificing efficiency, dilated or strided attention can be used. Similar to dilated convolutions, a token attends to neighbors with increasing gaps or strides. For example, a token might attend to positions i±1, i±2, i±4, i±8, and so on.
Dilated (strided) attention pattern where token 5 attends to itself, immediate neighbors (gap 1), neighbors with gap 2, and neighbors with gap 4. This allows covering a wider receptive field with a fixed number of computations per token.
This allows the receptive field to grow exponentially with the number of layers, enabling the capture of longer dependencies more efficiently than purely local attention, while maintaining a complexity of roughly O(N⋅k_dilated), where k_dilated is the number of positions attended to per token (often logarithmic in N).
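The sketch below constructs such a dilated mask, assuming power-of-two offsets for illustration; real models treat the offsets (possibly varying per layer or head) as hyperparameters.

```python
import torch

def dilated_mask(seq_len: int, offsets=(1, 2, 4, 8)) -> torch.Tensor:
    """Boolean mask where token i attends to itself and to positions i +/- o
    for each offset o. Power-of-two offsets are an illustrative choice."""
    idx = torch.arange(seq_len)
    dist = (idx[None, :] - idx[:, None]).abs()   # pairwise distances |i - j|
    mask = torch.eye(seq_len, dtype=torch.bool)  # every token attends to itself
    for o in offsets:
        mask |= dist == o                        # add the strided connections
    return mask

print(dilated_mask(12, offsets=(1, 2, 4)).int())
```

With offsets doubling at each step, a fixed number of connections per token covers a receptive field that grows geometrically with depth.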
Many successful sparse attention models, such as Longformer and BigBird, combine local attention with a few pre-selected global tokens. These global tokens (for example, the [CLS] token, or tokens identified as particularly important) have full attention capabilities: they can attend to all other tokens in the sequence, and all other tokens can attend to them. This ensures that information can be aggregated globally and distributed back to local contexts. This hybrid approach aims to get the best of both worlds: linear scaling for most tokens via local attention, plus the ability to capture critical long-range dependencies via the global tokens. The complexity is typically dominated by the local attention part, remaining close to O(N⋅k), assuming the number of global tokens is small relative to N.
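A schematic way to express this hybrid pattern is to start from a local band and then open up full rows and columns for the chosen global positions, as in the sketch below. The function name and the choice of position 0 as the global token are illustrative; Longformer and BigBird use optimized block layouts rather than this dense mask.

```python
import torch

def local_plus_global_mask(seq_len: int, window: int, global_idx) -> torch.Tensor:
    """Sliding-window mask with a few global positions opened up fully.

    Tokens listed in `global_idx` attend to every position, and every position
    attends to them; all remaining tokens keep only their local window.
    """
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window // 2  # local band
    g = torch.as_tensor(global_idx)
    mask[g, :] = True   # global tokens attend everywhere
    mask[:, g] = True   # everyone attends to the global tokens
    return mask

print(local_plus_global_mask(seq_len=10, window=4, global_idx=[0]).int())
```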
Implementing sparse attention often requires more complex index manipulation compared to the dense matrix multiplications of standard attention. Instead of computing the full N×N attention matrix, specific indices corresponding to the allowed sparse connections must be gathered or computed. Libraries and frameworks increasingly provide optimized implementations for common sparse patterns (e.g., block-sparse operations).
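As a simple illustration (not an optimized implementation), the sketch below applies a boolean sparsity mask inside standard scaled dot-product attention. Note that it still materializes the full N×N score matrix, so it demonstrates the pattern rather than the memory savings that block-sparse kernels provide.

```python
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, mask):
    """Scaled dot-product attention restricted by a boolean sparsity mask.

    This still builds the full N x N score matrix, so it illustrates the
    pattern only; real savings require kernels that skip disallowed
    entries entirely (e.g., block-sparse operations).
    """
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d ** 0.5         # (N, N) scores
    scores = scores.masked_fill(~mask, float("-inf"))   # block disallowed pairs
    return F.softmax(scores, dim=-1) @ v

N, d = 8, 16
q, k, v = (torch.randn(N, d) for _ in range(3))
idx = torch.arange(N)
band = (idx[None, :] - idx[:, None]).abs() <= 2         # local window, radius 2
print(masked_attention(q, k, v, band).shape)            # torch.Size([8, 16])
```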
Advantages:

- Reduces attention cost from quadratic in N to roughly linear (around O(N⋅k)), making much longer sequences feasible under the same compute and memory budget.
- Builds on the empirical observation that learned attention matrices are largely sparse, so restricting attention often discards little relevant information.

Disadvantages:

- The sparsity pattern is a hand-designed prior; dependencies that fall outside the pattern can only be captured indirectly, by stacking layers or routing through global tokens.
- Implementations are more complex than dense attention and typically rely on specialized (e.g., block-sparse) kernels to realize the theoretical savings in practice.
Sparse attention represents a significant step towards making Transformers applicable to tasks involving very long documents, high-resolution images (treated as sequences of patches), or extended time series. It's a prime example of modifying the core architecture to overcome inherent scaling limitations, a theme we will continue exploring with approximation techniques in the following sections.