趋近智
缩放点积注意力使得模型能够区分序列中不同token的优先级。然而,仅进行一次这样的计算可能会迫使注意力机制对不同类型的关联进行平均。设想尝试理解一个句子,例如“这只疲惫的动物没有过马路,因为它太宽了。”当关注“它”这个词时,单一的注意力机制可能难以同时良好地捕捉“疲惫的动物”的关联和“街道宽度”的关联。
多头注意力机制通过多次并行运行缩放点积注意力过程来解决此问题,每次都对原始的查询、键和值使用不同的学习变换。这使得每个“头”能够关注信息中不同的方面或表征子空间。
以下是分步过程:
线性投影: 多头注意力机制不使用单一的查询(Q)、键(K)和值(V)矩阵集合,而是首先创建 个不同的矩阵集合,其中 是注意力头的数量(一个超参数)。对于每个头 (从 到 ),原始输入Q、K和V矩阵(在自注意力的情况下通常源自相同的输入序列嵌入)使用学习到的权重矩阵 、 和 进行投影。
通常,这些投影矩阵的维度小于原始嵌入维度()。如果输入嵌入维度为 ,每个头通常使用维度 进行操作。这确保了总计算成本与具有完整维度的单一头注意力相似。这些权重矩阵()对于每个头都是独有的,并在训练过程中学习。
并行注意力计算: 然后,这些投影集合()中的每一个都同时送入各自的缩放点积注意力机制。这产生了 个独立的输出矩阵,我们称它们为 :
每个 矩阵根据头 学习到的特定投影来捕获注意力信息。因为投影不同(对于每个 , 都不同),每个头可能学习关注输入序列中不同类型的关联或特征。
拼接: 所有 个注意力头的输出沿着特征维度拼接在一起。如果每个 的维度是 ,拼接后的矩阵维度将是 。由于我们通常设置 ,拼接后的矩阵维度变为 ,这与原始输入嵌入维度相匹配。
最终线性投影: 此拼接后的输出随后会经过一个最终的线性投影层,由另一个学习到的权重矩阵 参数化。这个投影混合了不同头学习到的信息,并生成多头注意力层的最终输出,其维度通常为 。
这个完整的多头注意力模块随后可作为更大的Transformer架构中的一个组成部分使用,替换单一的缩放点积注意力机制。
下图展现了信息流经一个包含 个头的多头注意力模块的过程。
该图显示了输入Q、K和V矩阵如何首先为每个 个注意力头独立投影。然后,缩放点积注意力并行应用于每个投影集。产生的注意力输出被拼接并通过一个最终的线性层,以生成多头注意力输出。
通过允许不同头学习不同的投影矩阵(),多头注意力机制使得模型能够共同关注不同位置上来自不同表征子空间的信息,从而相较于使用单一注意力机制,带来了更丰富、更有效的表征。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造