理解训练好的Transformer模型如何进行预测,与优化其训练速度或内存使用同等重要。分析注意力权重分布能够帮助我们了解模型的内部推理过程。通过它,我们能观察模型在处理特定标记时,将注意力集中在输入序列的哪些部分。这种做法对于调试、验证模型行为以及弄清模型是学到了有意义的关联还是依赖于虚假相关性非常有帮助。获取注意力权重大多数现代深度学习框架都提供了机制来访问模型中的中间激活,包括注意力权重。具体方法取决于所使用的框架以及Transformer模型的实现方式。常见的方法包括:修改forward传递: 修改模型的forward方法(或等效方法),使其在返回主要输出的同时也返回注意力权重。如果您能控制模型的源代码,这通常很简单。使用钩子(Hooks): PyTorch等框架提供“钩子”,可以注册在特定模块(如Attention层)上。这些钩子可以在正向或反向传播过程中捕获模块的输出(或输入),而无需永久修改模型代码。模型配置: 某些预构建的Transformer实现(例如,来自Hugging Face Transformers等库)带有一个配置标志(如output_attentions=True),可以指示模型返回注意力权重。无论采用哪种方法,目标都是获取注意力概率矩阵,这些矩阵通常是在每个注意力头内部的softmax操作之后计算得到的。对于标准的多头注意力层,输出的注意力权重通常具有[batch_size, num_heads, sequence_length_query, sequence_length_key]的形状。对于自注意力,sequence_length_query和sequence_length_key是相同的(即输入序列长度)。对于解码器中的交叉注意力,sequence_length_key对应于编码器输出序列长度。假设您已经获取了批处理中单个样本的特定层和头的注意力权重。结果张量,我们称之为attention_probs,可能具有[num_heads, seq_len, seq_len]的形状。我们可以选择一个特定的头进行可视化,从而得到一个形状为[seq_len, seq_len]的二维矩阵。使用热力图可视化注意力可视化这些二维注意力矩阵最常用的方法是使用热力图。热力图中的每个单元格 $(i, j)$ 表示从第$i$个查询标记到第$j$个键标记的注意力权重。值越高(颜色越亮)表示注意力越强。考虑一个简单的输入句子:“The quick brown fox jumps”。我们来可视化第一层、第一个头的自注意力权重。# 使用Matplotlib/Seaborn的Python代码 # 假设'attention_matrix'是一个[seq_len, seq_len]的numpy数组 # 假设'tokens'是一个字符串列表:["The", "quick", "brown", "fox", "jumps"] import matplotlib.pyplot as plt import seaborn as sns import numpy as np # 示例注意力数据(请替换为实际提取的权重) # 行:查询标记(源自) # 列:标记(指向) attention_matrix = np.random.rand(5, 5) # 使对角线略微加强以增加真实感 np.fill_diagonal(attention_matrix, attention_matrix.diagonal() + 0.3) # 对行进行归一化,使其总和为1(如softmax输出) attention_matrix /= attention_matrix.sum(axis=1, keepdims=True) tokens = ["The", "quick", "brown", "fox", "jumps"] seq_len = len(tokens) plt.figure(figsize=(7, 6)) sns.heatmap(attention_matrix, xticklabels=tokens, yticklabels=tokens, cmap="viridis", annot=True, fmt=".2f") plt.xlabel("标记(指向)") plt.ylabel("查询标记(源自)") plt.title("自注意力权重(第一层,第一个头)") plt.xticks(rotation=45) plt.yticks(rotation=0) plt.tight_layout() plt.show() 这是使用Plotly进行网页可视化的一个示例,显示了相同句子的注意力权重。{"layout": {"title": "自注意力权重(第一层,第一个头)", "xaxis": {"title": "标记(指向)", "tickangle": 45}, "yaxis": {"title": "查询标记(源自)", "autorange": "reversed"}, "width": 600, "height": 500, "margin": {"l": 80, "r": 50, "b": 100, "t": 50}}, "data": [{"z": [[0.5, 0.1, 0.1, 0.1, 0.2], [0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.5, 0.2, 0.1], [0.1, 0.2, 0.2, 0.4, 0.1], [0.1, 0.1, 0.1, 0.2, 0.5]], "x": ["The", "quick", "brown", "fox", "jumps"], "y": ["The", "quick", "brown", "fox", "jumps"], "type": "heatmap", "colorscale": "Viridis", "colorbar": {"title": "注意力"}}]}热力图显示了标记之间的注意力分数。行代表生成查询的标记,列代表生成键/值的标记。更亮的单元格表示更高的注意力分数。解读注意力模式在分析这些可视化时,寻找重复出现的模式:自身注意力: 通常,标记会强烈关注自身(自注意力中对角线很强)。这有助于保持标记本身的表示。局部上下文: 相邻标记之间存在高注意力分数。这表明注意力头侧重于局部词语依赖性。句法依赖: 寻找反映语法的模式。例如,动词可能会强烈关注其主语或宾语,即使它们在序列中相距较远。形容词可能会关注它们所修饰的名词。特殊标记: 观察对[CLS]、[SEP]、[BOS]、[EOS]等特殊标记的注意力情况。有时,[CLS]标记(在类似BERT的模型中)会汇总整个序列的信息,显示出广泛的注意力。在解码器中,标记通常会关注前一个句子或片段的[EOS](序列结束)标记。头部专长: 为同一层内不同的注意力头生成热力图。您会经常发现截然不同的模式。一个头可能侧重于局部上下文,另一个侧重于特定的句法关系(如名词-动词配对),而另一个可能表现出更广泛、几乎一致的注意力。这种专长是多头注意力的动因。层级深度: 较低层中的模式通常更具局部性,并侧重于句法结构。在较高层中,注意力模式可以变得更抽象,可能反映语义关系或聚合更长距离的信息。交叉注意力(解码器): 当可视化编码器-解码器交叉注意力时(例如,在翻译中),寻找源语言标记(键)和目标语言标记(查询)之间的强对齐。理想情况下,目标词应该强烈关注其对应的源词。这里的错位可能表明翻译错误。考量与局限性尽管富有洞察力,但注意力可视化并非对模型行为的明确解释。相关性与因果性: 高注意力并不证明某个标记是特定输出的唯一或主要原因。它表明了信息流的强度。平均效应: 对注意力头之间权重进行平均可能会模糊单个注意力头学习到的专门模式。基于梯度的方法: 注意力权重反映的是正向传递。其他方法(如基于梯度的显著图)提供了补充性的视角,侧重于哪些输入对输出预测影响最大。Softmax之后: 这些权重是softmax函数产生的概率。Softmax之前的分数可能显示出不同的相对重要性,尽管概率更容易解释。分析注意力权重是一种实用技术,可帮助您定性地了解Transformer模型。它通过帮助您理解模型如何处理信息来补充定量评估指标,这对于构建更有效和可靠的系统至关重要。这种做法有助于确认复杂机制(如学习到的位置嵌入、层归一化策略或特定优化器)是否带来了合理的内部表示和信息流。