趋近智
大师班
缩放点积注意力是 Transformer 注意力机制的基本组成部分。这种机制允许模型在处理特定元素时,衡量输入序列不同部分的重要性。它不依赖循环,而是根据从输入中获得的查询、键和值之间的相互作用来计算注意力分数。
缩放点积注意力的核心计算定义为:
注意力(Q,K,V)=softmax(dkQKT)V让我们分析一下组成部分和实现步骤:
查询 (Q)、键 (K)、值 (V) 矩阵: 这些矩阵通常是输入嵌入的投影。对于给定的输入序列元素(表示为一个向量),我们生成:
查询 向量:表示当前元素正在寻找信息。键 向量:表示提供信息的元素,用于计算与查询的兼容性。值 向量:表示提供信息的元素的实际内容。
如果我们有一个序列批次,Q、K 和 V 变为矩阵,其中每行对应序列中的一个元素。它们的维度通常为 [batch_size,seq_len,dmodel];对于单个注意力头进行投影后,Q 和 K 的维度为 [batch_size,seq_len,dk],而 V 的维度为 [batch_size,seq_len,dv]。通常,dk=dv。计算点积 (QKT): 第一步是计算查询矩阵 Q 和键矩阵 KT 的转置之间的点积。此操作计算每个查询应关注每个键的程度。更高的点积表示查询和键之间具有更高的相关性或兼容性。结果矩阵通常称为 分数 或 能量,其维度为 [batch_size,seq_lenq,seq_lenk],其中 seq_lenq 是查询的序列长度,seq_lenk 是键的序列长度(在自注意力中它们通常是相同的)。
缩放 (dk...): 然后,通过除以键向量维度 dk 的平方根来缩放分数。这种缩放对于稳定训练过程很重要。没有它,对于较大的 dk 值,点积的幅度可能会变得非常大。softmax 函数的输入过大可能导致梯度极小,从而使学习变得困难。缩放确保 softmax 输入的方差保持合理。
应用掩码(可选): 在许多情况下,我们需要阻止关注某些位置。这通过在 softmax 步骤之前进行掩码来实现。
True 或 1)。我们会在这些位置的分数上添加一个大的负数(如 -1e9 或负无穷)。应用 Softmax: softmax 函数按行应用于缩放(并可能被掩盖)后的分数。这会将分数转换为概率分布,其中每个值表示一个查询分配给一个键的注意力权重。每个查询的权重总和为 1。结果矩阵通常称为 注意力权重,其维度为 [batch_size,seq_lenq,seq_lenk]。
乘以值 (...V): 最后,注意力权重矩阵乘以值矩阵 V。这一步计算值向量的加权和,其中权重由注意力概率决定。获得更高注意力权重的元素对输出的贡献更大。缩放点积注意力层的输出维度为 [batch_size,seq_lenq,dv]。
让我们将这些步骤转换为 PyTorch 函数。我们假设输入 query、key 和 value 是 3D 张量,表示序列批次,可能已经为特定的注意力头进行了投影。
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
计算缩放点积注意力。
参数:
query: 查询张量;形状为 (batch_size, num_heads, seq_len_q, d_k)
或 (batch_size, seq_len_q, d_k) (如果为单头)。
key: 键张量;形状为 (batch_size, num_heads, seq_len_k, d_k)
或 (batch_size, seq_len_k, d_k) (如果为单头)。
value: 值张量;形状为 (batch_size, num_heads, seq_len_v, d_v)
或 (batch_size, seq_len_v, d_v) (如果为单头)。
注意:seq_len_k 和 seq_len_v 必须相同。
mask: 可选的掩码张量;形状应可广播到
(batch_size, num_heads, seq_len_q, seq_len_k)。
`True` 或 `1` 的位置将被掩盖(设为 -inf)。
返回:
包含以下内容的元组:
- output:注意力输出张量;
形状为 (batch_size, num_heads, seq_len_q, d_v)
或 (batch_size, seq_len_q, d_v) (如果为单头)。
- attention_weights:注意力权重张量;
形状为 (batch_size, num_heads, seq_len_q, seq_len_k)
或 (batch_size, seq_len_q, seq_len_k) (如果为单头)。
"""
# 确保维度与矩阵乘法兼容
# K 需要形状 (..., d_k, seq_len_k) 才能与 Q (..., seq_len_q, d_k) 进行矩阵乘法
# 结果形状: (..., seq_len_q, seq_len_k)
d_k = query.size(-1)
scores = (torch.matmul(query, key.transpose(-2, -1))
/ math.sqrt(d_k))
# 如果提供了掩码,则应用掩码(将掩盖位置设为一个大的负值)
if mask is not None:
# 确保掩码具有兼容的维度或可以广播
# 常见掩码形状:(batch_size, 1, 1, seq_len_k) 用于填充掩码
# (batch_size, 1, seq_len_q, seq_len_k) 用于组合掩码
# 我们添加一个大的负值,而不是直接使用布尔掩码
# 以确保与各种 PyTorch 版本和操作兼容。
# 当掩码为 True(或 1)时,我们希望用 -inf 替换分数。
scores = scores.masked_fill(mask == True, float('-inf'))
# 或者使用一个大的负数,如 -1e9
# 应用 softmax 以获得注意力概率
# Softmax 应用于最后一个维度 (seq_len_k)
attention_weights = F.softmax(scores, dim=-1)
# 检查 softmax 后可能出现的 NaN,这可能发生在某一行中所有分数都为 -inf 的情况下
# 这可能表明掩码或输入数据存在问题
if torch.isnan(attention_weights).any():
print("警告:在注意力权重中检测到 NaN。 "
"请检查掩码或输入数据。")
# (可选)处理 NaN,例如,将其设为 0,
# 尽管这可能会隐藏潜在问题。
# attention_weights = torch.nan_to_num(attention_weights)
# 权重乘以值
# 结果形状: (..., seq_len_q, d_v)
output = torch.matmul(attention_weights, value)
return output, attention_weights
# 示例用法(为简单起见,假设为单头)
batch_size = 2
seq_len_q = 5 # 查询序列长度
seq_len_k = 7 # 键/值序列长度
d_k = 64 # 键/查询的维度
d_v = 128 # 值的维度
# 虚拟张量
query_tensor = torch.randn(batch_size, seq_len_q, d_k)
key_tensor = torch.randn(batch_size, seq_len_k, d_k)
value_tensor = torch.randn(batch_size, seq_len_k, d_v) # seq_len_k == seq_len_v
# 填充掩码示例(掩盖键/值序列的最后两个元素)
padding_mask = torch.zeros(batch_size, 1, seq_len_k, dtype=torch.bool)
padding_mask[:, :, -2:] = True # 掩盖位置 5 和 6
# 计算注意力
output_tensor, attention_weights_tensor = scaled_dot_product_attention(
query_tensor,
key_tensor,
value_tensor,
mask=padding_mask
)
print("输出形状:", output_tensor.shape) # 预期:[2, 5, 128]
print("注意力权重形状:", attention_weights_tensor.shape) # 预期:[2, 5, 7]
print("批次 0 中第一个查询的注意力权重 "
"(最后两个键已被掩盖):")
print(attention_weights_tensor[0, 0, :])
此函数封装了核心逻辑。请注意,掩码需要应用在 softmax 之前。使用一个大的负数配合 masked_fill 有效地阻止了被掩盖的位置在 softmax 归一化后对加权和做出贡献。该函数返回最终的加权输出和注意力权重本身,这有助于分析或可视化(正如我们将在第 23 章中看到的)。
这个基本组成部分现在将在多头注意力机制中使用,我们接下来将实现它。多头注意力将并行地多次运行此缩放点积注意力,并使用查询、键和值的不同学习到的投影。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造