趋近智
训练好的图神经网络 (neural network) (GNN) 依靠其损失函数 (loss function)来指导模型权重 (weight)的调整,从而达到有效的配置。然而,在训练集上获得较低的损失值并不总是能直接转化为在未见数据上的良好表现。为了真正了解模型的泛化能力,我们需要在预留的测试集上使用比原始损失值更易读的指标进行评估。对于节点分类,这些指标回答了一个简单的问题:模型正确标记 (token)了多少个节点?
用于节点分类的 GNN 通常以最后一个线性层结束,该层会为每个类别生成原始的、未归一化 (normalization)的分数,通常称为 logits。如果模型有 个类别,并且正在评估 个节点,则输出将是一个形状为 的张量。为了将这些分数转化为明确的预测,我们为每个节点选择分数最高的类别。这是通过沿类别维度执行 argmax 操作来完成的。
例如,如果单个节点的输出为 [0.1, 2.5, -1.3],则 argmax 为 1,这意味着模型预测为第二个类别(因为索引是从零开始的)。这个过程将模型的连续输出分数转换为离散的类别标签,以便我们与真实标签进行对比。
在计算任何指标之前,我们必须首先将模型的预测与真实标签进行比较。组织这种对比最基本的工具是混淆矩阵。在二分类任务中,混淆矩阵是一个 2x2 的表格,总结了预测的四种可能结果。
二分类混淆矩阵中的四种结果。正确的预测(TP, TN)位于主对角线上。
对于节点分类(通常是多分类问题),混淆矩阵会扩展为 的矩阵,其中 是类别的数量。第 行和第 列的条目表示真正属于类别 但被预测为类别 的节点数量。对角线元素代表所有被正确分类的节点。
准确率是最直观的指标。它衡量了正确预测占总预测的比例。
虽然准确率易于理解,但在类别不平衡的数据集上可能会产生误导。想象一个图,其中 95% 的节点属于类别 A,5% 属于类别 B。一个总是预测类别 A 的偷懒模型将获得 95% 的准确率,而没有学到关于类别 B 的任何有用信息。在这种情况下,我们需要更敏锐的指标。
为了更好地了解模型在不平衡数据上的表现,我们转向精确率、召回率和 F1 分数。这些指标通常是按类别计算的。对于给定类别,我们将其视为“正”类,而将所有其他类别视为“负”类。
精确率回答了这样一个问题:“在模型标记 (token)为类别 A 的所有节点中,有多少实际上是类别 A?”它衡量了模型正类预测的可靠性。
当假正类的代价很高时,高精确率非常。例如,在自动将学术论文标记为“撤稿”的系统中,你需要高精确率以避免错误地标记合法的论文。
召回率(也称为灵敏度或真正类率)回答了这样一个问题:“在所有真正属于类别 A 的节点中,模型找出了多少个?”它衡量了模型识别所有相关实例的能力。
当假负类的代价很高时,高召回率非常。例如,在预测哪些蛋白质与某种疾病相关的 GNN 中,你需要高召回率以避免错过任何潜在的重要蛋白质。
通常,精确率和召回率之间存在权衡。F1 分数提供了一种将两者结合成单个指标的方法。它是精确率和召回率的调和平均数,它会给较小的值赋予更大的权重 (weight)。这意味着只有当精确率和召回率都很高时,F1 分数才会很高。
由于精确率、召回率和 F1 是按类别计算的,我们需要一种策略将它们汇总为多分类节点分类任务的单个数值。两种最常用的方法是宏平均和加权平均。
宏平均 (Macro Average): 独立计算每个类别的指标,然后计算其算术平均值。这种方法将每个类别视为同等重要,而不考虑它包含多少个节点。如果你想知道模型在所有类别(包括稀有类别)上的表现,这是一个很好的衡量标准。
加权平均 (Weighted Average): 为每个类别计算指标,但在平均时,根据每个类别的支持数(该类别的真实实例数量)对每个类别的分数进行加权。这考虑了类别不平衡的情况。高加权平均 F1 分数表明模型在最常见的类别上表现良好。
在不平衡的数据集上,较高的加权平均分数可能会掩盖在少数类上的糟糕表现。较低的宏平均分数则反映了这一弱点。
评估指标的选择完全取决于你的应用目标。
在实践中,查看完整的分类报告通常很有帮助,该报告会单独显示每个类别的精确率、召回率和 F1 分数。像 scikit-learn 这样的库为此提供了方便的函数。
from sklearn.metrics import classification_report
# y_true: 真实标签(例如来自 data.test_mask)
# y_pred: 模型在测试集上的预测标签
# 假设 class_names 是标签的字符串列表
print(classification_report(y_true, y_pred, target_names=class_names))
# 精确率 召回率 f1-分数 支持数
# 类别 1 0.91 0.95 0.93 105
# 类别 2 0.75 0.82 0.78 80
# 类别 3 0.98 0.96 0.97 150
#
# 准确率 0.92 335
# 宏平均 0.88 0.91 0.89 335
# 加权平均 0.92 0.92 0.92 335
这份详细的报告让你能全面了解模型的优缺点,从而针对如何改进模型做出明智的决策。
这部分内容有帮助吗?
classification_report 和不同的平均策略,与Python代码片段直接相关。© 2026 ApX Machine LearningAI伦理与透明度•