如前所述,标准的自注意力机制虽然强大,但会带来显著的计算负担,尤其当输入序列变长时。了解这种计算成本,有助于把握本章后续会介绍的更高效架构的设计初衷。
缩放点积注意力公式中的主要运算构成了Transformer中自注意力机制的基本形式。
注意力 ( Q , K , V ) = softmax ( Q K T d k ) V \text{注意力}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 注意力 ( Q , K , V ) = softmax ( d k Q K T ) V
这里,Q Q Q (查询)、K K K (键) 和 V V V (值) 是从输入序列嵌入中得到的矩阵。设 N N N 为序列长度,d k d_k d k 为键和查询的维度,d v d_v d v 为值的维度。
主要计算步骤如下:
查询-键相似度计算: 计算矩阵乘积 Q K T QK^T Q K T 。
Q Q Q 的维度为 ( N × d k ) (N \times d_k) ( N × d k ) 。
K T K^T K T 的维度为 ( d k × N ) (d_k \times N) ( d k × N ) 。
得到的注意力得分矩阵 S = Q K T S = QK^T S = Q K T 的维度为 ( N × N ) (N \times N) ( N × N ) 。
此矩阵乘法的计算成本大约为 O ( N ⋅ d k ⋅ N ) = O ( N 2 d k ) O(N \cdot d_k \cdot N) = O(N^2 d_k) O ( N ⋅ d k ⋅ N ) = O ( N 2 d k ) 浮点运算 (FLOPs)。
缩放与Softmax: 将得分按 1 / d k 1/\sqrt{d_k} 1/ d k 进行缩放,并对每行应用softmax函数。
缩放涉及 N 2 N^2 N 2 次元素级乘法。
Softmax涉及对每个 N N N 行进行指数运算和归一化,每行所需的运算量与行长 N N N 成正比。总成本大约为 O ( N 2 ) O(N^2) O ( N 2 ) 。
与矩阵乘法相比,对于较大的 N N N 和 d k d_k d k ,此步骤的计算量通常较小。
值聚合: 将softmax输出 (即注意力权重矩阵 A A A ,维度为 ( N × N ) (N \times N) ( N × N ) ) 与值矩阵 V V V 相乘。
A A A 的维度为 ( N × N ) (N \times N) ( N × N ) 。
V V V 的维度为 ( N × d v ) (N \times d_v) ( N × d v ) 。
得到的输出矩阵 O = A V O = AV O = A V 的维度为 ( N × d v ) (N \times d_v) ( N × d v ) 。
此矩阵乘法的计算成本大约为 O ( N ⋅ N ⋅ d v ) = O ( N 2 d v ) O(N \cdot N \cdot d_v) = O(N^2 d_v) O ( N ⋅ N ⋅ d v ) = O ( N 2 d v ) 浮点运算 (FLOPs)。
主要影响因素和总体复杂度
综合这些步骤,总计算复杂度主要取决于两次大型矩阵乘法:O ( N 2 d k + N 2 d v ) O(N^2 d_k + N^2 d_v) O ( N 2 d k + N 2 d v ) 。
在许多标准Transformer配置中,d k d_k d k 和 d v d_v d v 的维度与整体模型嵌入维度 d m o d e l d_{model} d m o d e l 成比例 (通常 d k = d v = d m o d e l / h d_k = d_v = d_{model} / h d k = d v = d m o d e l / h ,其中 h h h 是注意力头的数量)。因此,复杂度通常概括为:
O ( N 2 ⋅ d m o d e l ) O(N^2 \cdot d_{model}) O ( N 2 ⋅ d m o d e l )
这种对序列长度 N N N 的二次方依赖是重要瓶颈。尽管这些运算在序列维度上高度并行化 (与循环模型不同),但总运算次数随 N N N 的增长而迅速增加。
内存复杂度
除了计算之外,还有显著的内存需求。中间注意力得分矩阵 Q K T QK^T Q K T (softmax之前或之后) 的维度为 ( N × N ) (N \times N) ( N × N ) 。存储该矩阵需要:
O ( N 2 ) O(N^2) O ( N 2 )
的内存。对于长序列 (例如 N > 4096 N > 4096 N > 4096 ),存储一个 ( N × N ) (N \times N) ( N × N ) 的浮点数矩阵可能会超出典型加速器 (如GPU) 的内存容量,甚至在考虑激活值、梯度和模型参数所需的内存之前。
该图表显示了计算成本的迅速分化。随着序列长度 N N N 的增加,标准自注意力机制的 O ( N 2 ) O(N^2) O ( N 2 ) 成本迅速超过线性 O ( N ) O(N) O ( N ) 增长,使其对于非常长的序列变得不切实际。Y轴使用对数刻度以适应较大的数值范围。
对长序列的影响
这种二次方计算和内存复杂度严重限制了标准Transformer应用于涉及极长序列的任务,例如:
处理被视为图像块序列的整个高分辨率图像。
分析完整的长篇文档或书籍。
对长篇时间序列数据建模 (例如,音频、传感器读数)。
处理基因组序列。
即使配备强大的硬件,序列长度超过几千个token后,在训练和推理期间,从运行时和内存角度来看都会变得困难。此限制直接促使了本章后续部分讨论的其他注意力机制和架构变体的发展,这些旨在减少这种二次方依赖。