并行注意力计算涉及将原始的查询 (Q Q Q )、键 (K K K ) 和值 (V V V ) 投影到 h h h 个不同的子空间中,对每个头 i i i 使用不同的学习到的线性变换 W i Q , W i K , W i V W^Q_i, W^K_i, W^V_i W i Q , W i K , W i V 。对于每个头,注意力计算独立且同时地进行。这种并行处理是多头注意力的一种明确特征,并且对它的效率和计算特性有很大的贡献。
对于每个头 i i i (i i i 的范围从 1 到 h h h ),我们准确地按照之前定义的方式计算带缩放的点积注意力,但使用的是该头特有的投影 矩阵 Q i Q_i Q i 、K i K_i K i 和 V i V_i V i :
head i = Attention ( Q i , K i , V i ) = softmax ( Q i K i T d k i ) V i \text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_{k_i}}}\right) V_i head i = Attention ( Q i , K i , V i ) = softmax ( d k i Q i K i T ) V i
这里:
Q i = Q W i Q Q_i = Q W^Q_i Q i = Q W i Q 表示为头 i i i 投影的查询。
K i = K W i K K_i = K W^K_i K i = K W i K 表示为头 i i i 投影的键。
V i = V W i V V_i = V W^V_i V i = V W i V 表示为头 i i i 投影的值。
d k i d_{k_i} d k i 是头 i i i 内部 键(和查询)的维度。这个缩放因子,类似于单头带缩放点积注意力中使用的,能在训练期间稳定梯度。
正确管理维度是很重要的。如果输入嵌入维度是 d model d_{\text{model}} d model 并且我们使用 h h h 个头,投影通常被设计成使每个头的键、查询 (d k i d_{k_i} d k i ) 和值 (d v i d_{v_i} d v i ) 的维度相等:d k i = d v i = d model / h d_{k_i} = d_{v_i} = d_{\text{model}} / h d k i = d v i = d model / h 。这种划分确保了总计算成本与使用完整 d model d_{\text{model}} d model 维度的单头注意力机制相似,同时将表示能力分配到多个头。此外,它保证了当所有头的输出稍后被拼接时,结果维度与后续层所需的输入维度 d model d_{\text{model}} d model 相匹配,从而保持了模型架构的统一性。
假设输入序列的长度为 N N N (词元数量),为了简化,忽略批处理维度,头 i i i 的矩阵形状通常是:
Q i Q_i Q i : N × d k i N \times d_{k_i} N × d k i
K i K_i K i : N × d k i N \times d_{k_i} N × d k i (因为在自注意力中,键和查询来自相同的输入序列)
V i V_i V i : N × d v i N \times d_{v_i} N × d v i
因此,头 i i i 的注意力计算输出,记为 head i \text{head}_i head i ,将具有形状 N × d v i N \times d_{v_i} N × d v i 。由于我们通常设置 d v i = d k i = d model / h d_{v_i} = d_{k_i} = d_{\text{model}} / h d v i = d k i = d model / h ,输出形状为 N × ( d model / h ) N \times (d_{\text{model}} / h) N × ( d model / h ) 。
从计算角度看,这种结构非常适合并行处理。现代深度学习框架和像 GPU 这样的硬件擅长执行大型矩阵乘法。所有 h h h 个头的计算通常可以并行执行,而不是顺序地迭代每个头。这通常通过在注意力计算之前重塑投影的 Q、K 和 V 张量来实现,使其包含一个独立的“头”维度。例如,表示批处理查询的张量可以从 (batch_size, seq_len, d_model) 重塑为 (batch_size, num_heads, seq_len, d_k_i)。用于 Q i K i T d k i \frac{Q_i K_i^T}{\sqrt{d_{k_i}}} d k i Q i K i T 项的批量矩阵乘法(matmul),然后是 softmax 和与 V i V_i V i 的最终 matmul,可以同时高效地在批处理和头维度上运行。
对每个头,使用其特有的投影 Q、K、V 矩阵 (Q i , K i , V i Q_i, K_i, V_i Q i , K i , V i ) 执行独立的带缩放点积注意力计算。输出(head 1 , . . . , head h \text{head}_1, ..., \text{head}_h head 1 , ... , head h ),每个的维度为 N × d v i N \times d_{v_i} N × d v i ,在传递到下一阶段之前并行生成。注意,d k i d_{k_i} d k i 和 d v i d_{v_i} d v i 代表 d model / h d_{\text{model}}/h d model / h 。
这种并行结构的主要优点不限于计算效率。它允许每个注意力头可能专门学习不同类型的关系,或者同时关注来自不同表示子空间的信息。例如,一个头可以学习侧重于局部的句法依赖(如形容词与名词的一致性),而另一个头则捕捉更远距离的语义联系(如跨句子的指代消解),还有一个头可能侧重于位置关系。单个注意力机制将被迫平均这些可能不同的信号,这可能会稀释信息。多头注意力为信息流提供了多个独立的“通道”,让模型能够汇集多样化的关系信息,并最终构建更丰富、更具上下文意识的表示。
这些并行计算的输出,head 1 , head 2 , . . . , head h \text{head}_1, \text{head}_2, ..., \text{head}_h head 1 , head 2 , ... , head h ,捕获了输入序列内部关系的不同方面。它们现在已准备好在下一步中组合:先拼接,再进行最终的线性投影。