在为每个 h 个头并行计算缩放点积注意力后,我们会得到 h 个不同的输出向量,通常表示为 head1,head2,...,headh。每个 headi 都是通过对原始查询、键和值使用投影版本来应用注意力机制的结果。
headi=Attention(XWiQ,XWiK,XWiV)
其中 X 表示输入序列嵌入(或来自上一层的输出),并且 WiQ、WiK、WiV 是第 i 个头的参数矩阵。每个 headi 的维度为 dv,其中 dv=dmodel/h。这种划分使得所有头的计算成本与具有完整 dmodel 维度的单头注意力机制的计算成本相当。
然而,Transformer 中的后续层,例如逐位置前馈网络和残差连接,需要一个具有模型隐藏维度 dmodel 的单个输入张量。因此,我们需要一种机制将这些并行头的输出整合回一个单一的表示形式。
组合头输出:拼接
整合不同头所捕获信息的首要步骤是直接的:拼接。所有 h 个头的输出向量沿其特征维度进行拼接。
如果每个 headi 是一个形状为(序列长度,dv)的矩阵,拼接操作会将这些矩阵并排堆叠,产生一个形状为(序列长度,h×dv)的新矩阵。由于我们将 dv 定义为 dmodel/h,得到的拼接矩阵具有(序列长度,dmodel)的维度。
拼接(head1,head2,...,headh)∈R序列长度×(h⋅dv)=R序列长度×dmodel
此操作有效汇聚了每个头所关注的不同表征子空间中获取的见解。
流程图显示了单个注意力头的输出如何被拼接,然后通过最终线性层进行投影。
最终线性投影
拼接汇聚了头部的输出,但结果张量只是简单地将这些专用表示并置。为了让这些表示能够有效交互和组合,并确保输出维度与后续层(特别是残差连接)所需的 dmodel 相匹配,会应用一个最终线性投影。
这种投影通过另一个学习到的权重矩阵 WO 实现,其维度为 (h×dv,dmodel),简化后为 (dmodel,dmodel)。拼接后的输出与此权重矩阵相乘:
多头输出=拼接(head1,...,headh)WO
这种最终线性变换作用显著:
- 信息混合: 它使得从不同头(代表不同子空间)中学习到的信息得以组合和整合。线性层充当一个学习到的组合函数。
- 维度匹配: 它确保输出张量具有Transformer块架构其余部分所需的精确 dmodel 维度,从而能够进行残差连接(加和归一化)等操作。
根本上,多头机制首先将模型的表示能力分解为多个子空间(h 个头,每个维度为 dv),允许每个头在其子空间内专门关注输入的不同方面,然后使用拼接,再进行线性投影(WO),将这些专业化的视图合并回一个维度为 dmodel 的单一、更丰富的表示。
这种结构使模型能够同时捕获多种关系模式(例如,短距离依赖、长距离依赖、句法关系),这对于单一注意力机制而言将更具挑战。投影矩阵(每个头的 WiQ,WiK,WiV)和最终输出投影矩阵(WO)的所有参数都在模型训练期间进行端到端学习。