这个实践练习侧重于实现主要的注意力计算,特别是缩放点积注意力方式。这种方式是Transformer处理信息的根本,它使得模型能够衡量输入序列中不同元素彼此之间的分量。我们将实现以下公式:$$ \text{注意力}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$这个函数以查询 ($Q$)、键 ($K$) 和值 ($V$) 矩阵作为输入,同时还有键向量的维度 ($d_k$) 用于缩放。另外,它还可以应对一个掩码,以阻止对某些位置(比如填充标记或解码器中的未来标记)进行注意力计算。环境准备我们将使用 PyTorch 进行此次实现,但其思想可顺畅地应用于其他深度学习架构,例如 TensorFlow。请确保您已安装 PyTorch。我们还需要 math 库来进行平方根计算。import torch import torch.nn.functional as F import math实现注意力函数让我们定义一个 Python 函数 scaled_dot_product_attention 来进行计算。它将接受 $Q$、$K$、$V$ 的张量,以及一个可选的 mask。def scaled_dot_product_attention(query, key, value, mask=None): """ 计算缩放点积注意力。 参数: query (torch.Tensor): 查询张量;形状 (batch_size, ..., seq_len_q, d_k) (torch.Tensor): 键张量;形状 (batch_size, ..., seq_len_k, d_k) value (torch.Tensor): 值张量;形状 (batch_size, ..., seq_len_v, d_v) 注意:seq_len_k 和 seq_len_v 必须相同。 mask (torch.Tensor, optional): 掩码张量;形状必须可以广播 到 (batch_size, ..., seq_len_q, seq_len_k)。 默认为 None。 返回: torch.Tensor: 输出张量;形状 (batch_size, ..., seq_len_q, d_v) torch.Tensor: 注意力权重;形状 (batch_size, ..., seq_len_q, seq_len_k) """ # 获取向量的维度 d_k = query.size(-1) # 1. 计算点积:Q * K^T # 结果形状:(batch_size, ..., seq_len_q, seq_len_k) ```python attention_scores = torch.matmul(query, key.transpose(-2, -1))# 2. 缩放分数 attention_scores = attention_scores / math.sqrt(d_k) # 3. 应用掩码(如果提供) # 掩码指示要忽略的位置(例如,填充)。 # 在 softmax 之前,我们将一个较大的负数 (-1e9) 添加到这些位置。 if mask is not None: # 确保掩码具有兼容的形状 attention_scores = attention_scores.masked_fill(mask == 0, -1e9) # 4. 应用 softmax 获取注意力权重 # Softmax 应用于最后一个维度(`seq_len_k`) # 结果形状:(batch_size, ..., seq_len_q, seq_len_k) attention_weights = F.softmax(attention_scores, dim=-1) # 5. 将权重乘以值向量 V # 结果形状:(batch_size, ..., seq_len_q, d_v) output = torch.matmul(attention_weights, value) return output, attention_weights 让我们分解函数中的步骤: 1. **矩阵乘法 ($QK^T$)**:我们计算每个查询向量与所有键向量之间的点积。`torch.matmul` 进行批处理和矩阵乘法。`key.transpose(-2, -1)` 操作交换键张量的最后两个维度,从而有效地转置键矩阵以进行乘法。此步骤算出查询和键之间的初始对齐分数。 2. **缩放**:分数除以维度 ($d_k$) 的平方根。如前所述,这种缩放避免点积变得过大,否则可能会使得 softmax 函数进入梯度很小的区域,从而影响学习。 3. **掩码(非必需)**:如果提供了 `mask`,我们在此处应用它。掩码一般在应避免注意力的位置(比如填充标记或序列中的未来位置)为 `0`,在其他位置为 `1`。我们使用 `masked_fill` 将被掩码位置 (`mask == 0`) 的分数替换为一个非常大的负数 (`-1e9`)。接下来应用 softmax 时,这些位置将获得接近零的概率。 4. **Softmax**:`F.softmax` 函数应用于最后一个维度(序列长度维度,`seq_len_k`)。这会将缩放后的分数变为概率分布,表示注意力权重。每个查询位置的权重在所有键位置上相加将为 1。 5. **矩阵乘法(权重 * V)**:最后,注意力权重乘以值 ($V$) 张量。这会计算值向量的加权和,其中权重由注意力分布决定。结果是注意力方式的输出,表示输入序列中每个查询位置的上下文有关的信息得到突出。 该函数返回最终的输出张量和注意力权重,这对于分析和可视化很有帮助。 ### 使用示例 让我们创建一些示例张量,看看函数的效果。我们假设批大小为 1,序列长度为 4,嵌入维度 ($d_k$, $d_v$) 为 8。在自注意力中,Q、K 和 V 一般来源于相同的输入序列,因此 `seq_len_q`、`seq_len_k` 和 `seq_len_v` 通常是相同的。 ```python # 示例参数 batch_size = 1 seq_len = 4 d_k = 8 # 键/查询的维度 d_v = 8 # 值的维度 # 创建随机查询、值张量 # 在实际模型中,这些将来自通过线性层投影的输入嵌入 query = torch.randn(batch_size, seq_len, d_k) key = torch.randn(batch_size, seq_len, d_k) value = torch.randn(batch_size, seq_len, d_v) # 计算注意力 output, attention_weights = scaled_dot_product_attention(query, key, value) print("输入查询形状:", query.shape) print("输入键形状:", key.shape) print("输入值形状:", value.shape) print("\n输出形状:", output.shape) print("注意力权重形状:", attention_weights.shape) print("\n注意力权重样本(第一个批次元素):\n", attention_weights[0])您应该看到类似于此的输出(值会因随机性而有所不同):输入查询形状: torch.Size([1, 4, 8]) 输入键形状: torch.Size([1, 4, 8]) 输入值形状: torch.Size([1, 4, 8]) 输出形状: torch.Size([1, 4, 8]) 注意力权重形状: torch.Size([1, 4, 4]) 注意力权重样本(第一个批次元素): tensor([[0.1813, 0.3056, 0.3317, 0.1814], [0.2477, 0.2080, 0.3401, 0.2042], [0.2880, 0.1807, 0.2523, 0.2790], [0.3139, 0.1774, 0.2614, 0.2473]])请注意,输出形状 (1, 4, 8) 与查询和值序列长度以及值维度 ($d_v$) 相符。注意力权重形状 (1, 4, 4) 表示从 4 个查询位置中的每个位置到 4 个键位置中的每个位置的注意力分数。注意力权重样本中的每一行近似相加为 1。可视化注意力权重可视化注意力权重可以给出模型在处理特定元素时,对输入序列的哪些部分侧重的认知。让我们为刚刚计算的 attention_weights 使用一个简单的热力图。{"data": [{"z": [[0.1813, 0.3056, 0.3317, 0.1814], [0.2477, 0.2080, 0.3401, 0.2042], [0.2880, 0.1807, 0.2523, 0.2790], [0.3139, 0.1774, 0.2614, 0.2473]], "x": ["位置 1", "位置 2", "位置 3", "位置 4"], "y": ["查询位置 1", "查询位置 2", "查询位置 3", "查询位置 4"], "type": "heatmap", "hoverongaps": false, "colorscale": [[0.0, "#e9ecef"], [0.5, "#74c0fc"], [1.0, "#1c7ed6"]]}], "layout": {"title": "注意力权重示例", "xaxis": {"title": "位置"}, "yaxis": {"title": "查询位置", "autorange": "reversed"}, "width": 500, "height": 450}}注意力权重以热力图形式呈现。每个单元格 (i, j) 表示从查询位置 i 到键位置 j 的注意力权重。深蓝色表示更高的注意力。这个可视化显示了每个查询位置(行)对每个位置(列)的关注程度。在实际应用中,比如将“hello world”翻译成法语时,您可能会看到在生成法语单词“world”时,注意力方式主要集中在输入单词“world”上。小结在本节中,您实现了缩放点积注意力,它是 Transformer 中注意力功能的主要计算单元。您了解了如何计算查询和键之间的分数、对其进行缩放、另外可以选择应用掩码、使用 softmax 进行标准化以获得权重,最后计算值的加权和。这个函数是多头注意力方式中使用的构建单元,使得 Transformer 能够良好地处理序列信息。