标准自注意力机制的二次复杂度,即 O ( N 2 d ) O(N^2 d) O ( N 2 d ) (其中 N N N 是序列长度,d d d 是模型维度),在处理长序列时是一个主要瓶颈。尽管功能强大,但随着 N N N 的增加,计算每对标记之间的注意力分数会变得计算成本过高。
有几种方法旨在减轻这种计算负担。一个重要方向是利用低秩投影来近似注意力机制,Linformer 模型就是其中的一个例子。
低秩假设
Linformer 的核心思路是基于这样一个假设:尽管自注意力机制有能力模拟整个序列中复杂的关联,但它通常可以通过一个低秩矩阵进行近似。本质上,N × N N \times N N × N 的注意力矩阵 P = S o f t m a x ( Q K T d k ) P = Softmax(\frac{QK^T}{\sqrt{d_k}}) P = S o f t ma x ( d k Q K T ) 可能在结构上存在冗余。这意味着从输入上下文(由值 V V V 表示)到输出上下文(加权和 P V PV P V )的映射不一定需要任意 N × N N \times N N × N 矩阵的完整表达能力。如果这个假设成立,我们就可以通过使用压缩表示来大幅提高效率。
Linformer:键和值的投影
Linformer(线性 Transformer)提出了一种巧妙的方法来实现线性复杂度,即引入学习到的投影矩阵 E i E_i E i 和 F i F_i F i 。Linformer 不计算完整的 N × N N \times N N × N 注意力矩阵,而是在注意力计算之前 沿序列长度维度投影键 (K K K ) 和值 (V V V ) 矩阵。
设输入序列长度为 N N N ,头维度为 d k d_k d k (用于键/查询)或 d v d_v d v (用于值)。原始矩阵为:
查询 Q ∈ R N × d k Q \in \mathbb{R}^{N \times d_k} Q ∈ R N × d k
键 K ∈ R N × d k K \in \mathbb{R}^{N \times d_k} K ∈ R N × d k
值 V ∈ R N × d v V \in \mathbb{R}^{N \times d_v} V ∈ R N × d v
Linformer 引入了两个投影矩阵 E ∈ R k × N E \in \mathbb{R}^{k \times N} E ∈ R k × N 和 F ∈ R k × N F \in \mathbb{R}^{k \times N} F ∈ R k × N ,其中 k k k 是一个投影维度,远小于 N N N (k ≪ N k \ll N k ≪ N )。这些矩阵用于创建投影键矩阵和投影值矩阵:
K p r o j = E K ( 维度 k × d k ) K_{proj} = E K \quad (\text{维度 } k \times d_k) K p ro j = E K ( 维度 k × d k )
V p r o j = F V ( 维度 k × d v ) V_{proj} = F V \quad (\text{维度 } k \times d_v) V p ro j = F V ( 维度 k × d v )
请注意,投影如何将序列维度从 N N N 降低到 k k k 。主要步骤是注意力分数现在在原始查询矩阵 Q Q Q 和投影 键矩阵 K p r o j K_{proj} K p ro j 之间计算:
P p r o j = S o f t m a x ( Q K p r o j T d k ) ( 维度 N × k ) P_{proj} = Softmax\left(\frac{Q K_{proj}^T}{\sqrt{d_k}}\right) \quad (\text{维度 } N \times k) P p ro j = S o f t ma x ( d k Q K p ro j T ) ( 维度 N × k )
最终输出通过将这个投影注意力分数矩阵 P p r o j P_{proj} P p ro j 乘以投影 值矩阵 V p r o j V_{proj} V p ro j 获得:
A t t e n t i o n L i n f o r m e r ( Q , K , V ) = P p r o j V p r o j ( 维度 N × d v ) Attention_{Linformer}(Q, K, V) = P_{proj} V_{proj} \quad (\text{维度 } N \times d_v) A tt e n t i o n L in f or m er ( Q , K , V ) = P p ro j V p ro j ( 维度 N × d v )
复杂度分析
让我们分析计算复杂度。原始注意力计算主要由 Q K T Q K^T Q K T 矩阵乘法决定,这需要 O ( N 2 d k ) O(N^2 d_k) O ( N 2 d k ) 时间,以及随后与 V V V 的乘法,这需要 O ( N 2 d v ) O(N^2 d_v) O ( N 2 d v ) 时间。
在 Linformer 中:
投影键 K:E K E K E K 需要 O ( N k d k ) O(Nk d_k) O ( N k d k ) 时间。
投影值 V:F V F V F V 需要 O ( N k d v ) O(Nk d_v) O ( N k d v ) 时间。
计算 Q K p r o j T Q K_{proj}^T Q K p ro j T :这涉及将一个 N × d k N \times d_k N × d k 矩阵乘以一个 d k × k d_k \times k d k × k 矩阵,结果需要 O ( N k d k ) O(Nk d_k) O ( N k d k ) 时间。
计算 P p r o j V p r o j P_{proj} V_{proj} P p ro j V p ro j :这涉及将一个 N × k N \times k N × k 矩阵乘以一个 k × d v k \times d_v k × d v 矩阵,结果需要 O ( N k d v ) O(Nk d_v) O ( N k d v ) 时间。
由于 k k k 被选择为使得 k ≪ N k \ll N k ≪ N ,Linformer 注意力机制的总体复杂度变为 O ( N k ) O(Nk) O ( N k ) ,这相对于序列长度 N N N 是线性的。这是对标准 O ( N 2 ) O(N^2) O ( N 2 ) 复杂度的大幅改进。
标准自注意力与 Linformer 投影注意力的计算流程比较。Linformer 引入了投影矩阵 (E, F),以在计算注意力分数之前降低键和值沿着序列长度轴的维度。
实现考量
投影矩阵: 投影矩阵 E E E 和 F F F 通常在训练期间学习。它们可以在不同的注意力头甚至层之间共享,以进一步减少参数数量。一个常见实现是使用简单的线性层作用于转置的键和值矩阵 (K T K^T K T , V T V^T V T ),以高效地执行投影。
k 的选择: 投影维度 k k k 是一个超参数。通常使用 128、256 或 512 等值,这些值远小于典型序列长度(数千或数万),Linformer 在这些情况下变得有优势。这个选择影响计算效率和模型准确性之间的权衡。较小的 k k k 运算更快,但可能导致更大的近似误差。
理论保证: Linformer 论文提供了理论分析,表明在某些假设下,自注意力矩阵确实是低秩的,并且可以通过这种投影方法很好地近似。
优点和缺点
优点:
线性时间复杂度: 将注意力时间复杂度从 O ( N 2 ) O(N^2) O ( N 2 ) 降低到 O ( N k ) O(Nk) O ( N k ) 。
线性空间复杂度: 减少了存储注意力矩阵所需的内存占用。
可扩展性: 能够处理比标准 Transformer 长得多的序列。
修改简单: 通过增加投影层,实现起来相对简单。
缺点:
近似误差: 作为一种近似方法,它可能无法捕获完整注意力矩阵的所有细节,可能导致在某些任务上性能下降,特别是对于 N 2 N^2 N 2 计算量可控的较短序列。
超参数调优: 需要调整投影维度 k k k 。
低秩假设: 其有效性依赖于底层假设,即注意力矩阵可以很好地由低秩结构近似,但这对于所有数据或任务可能并非都同样适用。
Linformer 代表着构建更高效 Transformer 模型的一个重要步骤。通过质疑完整二次注意力计算的必要性,并借助低秩近似的思路,它提供了一种实用的方法来扩展 Transformer 到更长的序列,为涉及大量文档、高分辨率图像或长篇音频的应用带来了可能性。