趋近智
这个实践练习侧重于实现主要的注意力计算,特别是缩放点积注意力方式。这种方式是Transformer处理信息的根本,它使得模型能够衡量输入序列中不同元素彼此之间的分量。
我们将实现以下公式:
注意力(Q,K,V)=softmax(dkQKT)V这个函数以查询 (Q)、键 (K) 和值 (V) 矩阵作为输入,同时还有键向量的维度 (dk) 用于缩放。另外,它还可以应对一个掩码,以阻止对某些位置(比如填充标记或解码器中的未来标记)进行注意力计算。
我们将使用 PyTorch 进行此次实现,但其思想可顺畅地应用于其他深度学习架构,例如 TensorFlow。请确保您已安装 PyTorch。我们还需要 math 库来进行平方根计算。
import torch
import torch.nn.functional as F
import math
让我们定义一个 Python 函数 scaled_dot_product_attention 来进行计算。它将接受 Q、K、V 的张量,以及一个可选的 mask。
def scaled_dot_product_attention(query, key, value, mask=None):
"""
计算缩放点积注意力。
参数:
query (torch.Tensor): 查询张量;形状 (batch_size, ..., seq_len_q, d_k)
(torch.Tensor): 键张量;形状 (batch_size, ..., seq_len_k, d_k)
value (torch.Tensor): 值张量;形状 (batch_size, ..., seq_len_v, d_v)
注意:seq_len_k 和 seq_len_v 必须相同。
mask (torch.Tensor, optional): 掩码张量;形状必须可以广播
到 (batch_size, ..., seq_len_q, seq_len_k)。
默认为 None。
返回:
torch.Tensor: 输出张量;形状 (batch_size, ..., seq_len_q, d_v)
torch.Tensor: 注意力权重;形状 (batch_size, ..., seq_len_q, seq_len_k)
"""
# 获取向量的维度
d_k = query.size(-1)
# 1. 计算点积:Q * K^T
# 结果形状:(batch_size, ..., seq_len_q, seq_len_k)
```python
attention_scores = torch.matmul(query, key.transpose(-2, -1))
# 2. 缩放分数
attention_scores = attention_scores / math.sqrt(d_k)
# 3. 应用掩码(如果提供)
# 掩码指示要忽略的位置(例如,填充)。
# 在 softmax 之前,我们将一个较大的负数 (-1e9) 添加到这些位置。
if mask is not None:
# 确保掩码具有兼容的形状
attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
# 4. 应用 softmax 获取注意力权重
# Softmax 应用于最后一个维度(`seq_len_k`)
# 结果形状:(batch_size, ..., seq_len_q, seq_len_k)
attention_weights = F.softmax(attention_scores, dim=-1)
# 5. 将权重乘以值向量 V
# 结果形状:(batch_size, ..., seq_len_q, d_v)
output = torch.matmul(attention_weights, value)
return output, attention_weights
让我们分解函数中的步骤:
1. **矩阵乘法 ($QK^T$)**:我们计算每个查询向量与所有键向量之间的点积。`torch.matmul` 进行批处理和矩阵乘法。`key.transpose(-2, -1)` 操作交换键张量的最后两个维度,从而有效地转置键矩阵以进行乘法。此步骤算出查询和键之间的初始对齐分数。
2. **缩放**:分数除以维度 ($d_k$) 的平方根。如前所述,这种缩放避免点积变得过大,否则可能会使得 softmax 函数进入梯度很小的区域,从而影响学习。
3. **掩码(非必需)**:如果提供了 `mask`,我们在此处应用它。掩码一般在应避免注意力的位置(比如填充标记或序列中的未来位置)为 `0`,在其他位置为 `1`。我们使用 `masked_fill` 将被掩码位置 (`mask == 0`) 的分数替换为一个非常大的负数 (`-1e9`)。接下来应用 softmax 时,这些位置将获得接近零的概率。
4. **Softmax**:`F.softmax` 函数应用于最后一个维度(序列长度维度,`seq_len_k`)。这会将缩放后的分数变为概率分布,表示注意力权重。每个查询位置的权重在所有键位置上相加将为 1。
5. **矩阵乘法(权重 * V)**:最后,注意力权重乘以值 ($V$) 张量。这会计算值向量的加权和,其中权重由注意力分布决定。结果是注意力方式的输出,表示输入序列中每个查询位置的上下文有关的信息得到突出。
该函数返回最终的输出张量和注意力权重,这对于分析和可视化很有帮助。
### 使用示例
让我们创建一些示例张量,看看函数的效果。我们假设批大小为 1,序列长度为 4,嵌入维度 ($d_k$, $d_v$) 为 8。在自注意力中,Q、K 和 V 一般来源于相同的输入序列,因此 `seq_len_q`、`seq_len_k` 和 `seq_len_v` 通常是相同的。
```python
# 示例参数
batch_size = 1
seq_len = 4
d_k = 8 # 键/查询的维度
d_v = 8 # 值的维度
# 创建随机查询、值张量
# 在实际模型中,这些将来自通过线性层投影的输入嵌入
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_v)
# 计算注意力
output, attention_weights = scaled_dot_product_attention(query, key, value)
print("输入查询形状:", query.shape)
print("输入键形状:", key.shape)
print("输入值形状:", value.shape)
print("\n输出形状:", output.shape)
print("注意力权重形状:", attention_weights.shape)
print("\n注意力权重样本(第一个批次元素):\n", attention_weights[0])
您应该看到类似于此的输出(值会因随机性而有所不同):
输入查询形状: torch.Size([1, 4, 8])
输入键形状: torch.Size([1, 4, 8])
输入值形状: torch.Size([1, 4, 8])
输出形状: torch.Size([1, 4, 8])
注意力权重形状: torch.Size([1, 4, 4])
注意力权重样本(第一个批次元素):
tensor([[0.1813, 0.3056, 0.3317, 0.1814],
[0.2477, 0.2080, 0.3401, 0.2042],
[0.2880, 0.1807, 0.2523, 0.2790],
[0.3139, 0.1774, 0.2614, 0.2473]])
请注意,输出形状 (1, 4, 8) 与查询和值序列长度以及值维度 (dv) 相符。注意力权重形状 (1, 4, 4) 表示从 4 个查询位置中的每个位置到 4 个键位置中的每个位置的注意力分数。注意力权重样本中的每一行近似相加为 1。
可视化注意力权重可以给出模型在处理特定元素时,对输入序列的哪些部分侧重的认知。让我们为刚刚计算的 attention_weights 使用一个简单的热力图。
注意力权重以热力图形式呈现。每个单元格 (i, j) 表示从查询位置 i 到键位置 j 的注意力权重。深蓝色表示更高的注意力。
这个可视化显示了每个查询位置(行)对每个位置(列)的关注程度。在实际应用中,比如将“hello world”翻译成法语时,您可能会看到在生成法语单词“world”时,注意力方式主要集中在输入单词“world”上。
在本节中,您实现了缩放点积注意力,它是 Transformer 中注意力功能的主要计算单元。您了解了如何计算查询和键之间的分数、对其进行缩放、另外可以选择应用掩码、使用 softmax 进行标准化以获得权重,最后计算值的加权和。这个函数是多头注意力方式中使用的构建单元,使得 Transformer 能够良好地处理序列信息。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造