趋近智
将节点嵌入 (embedding)转化为特定任务的预测结果是图神经网络 (neural network)应用中的一个核心步骤。对于节点分类等任务,这种转化通常涉及最后一个线性层。该层将每个节点的 D 维嵌入映射到一个 C 维向量 (vector),其中 C 代表类别数量。得到的输出向量包含每个类别的原始未归一化 (normalization)分数,通常被称为 logits。
为了训练模型,我们需要一种方法来衡量这些预测值与真实标签之间的差距。这就是损失函数 (loss function)(或目标函数)的作用。损失函数计算出一个标量值来量化 (quantization)模型的误差。训练的目标就是调整模型权重 (weight),使这个值降到最低。
对于多分类节点任务,最常用的损失函数 (loss function)是交叉熵损失 (Cross-Entropy Loss)。该函数是深度学习 (deep learning)分类问题的标准选择,在 GNN 中表现也非常出色。
它的运行分为两个阶段:
由于只有一个 为 1(即正确类别),其余均为 0,因此该公式简化为计算正确类别预测概率的负对数。正确类别的预测概率越高,损失就越低,这正是我们期望的结果。
在实际操作中,PyTorch 和 TensorFlow 等深度学习库提供了一个统一的函数(如 torch.nn.CrossEntropyLoss),它结合了 Softmax 激活和交叉熵计算。推荐使用这种组合函数,因为它比分开执行两个步骤具有更好的数值稳定性。
单个节点的损失计算过程。GNN 的输出嵌入 (embedding)通过分类器获得概率,然后使用损失函数将其与真实标签进行对比。
有时一个节点可以同时属于多个类别。例如,引用网络中的一篇研究论文可能同时属于“图神经网络 (neural network)”和“强化学习 (reinforcement learning)”。这是一个多标签分类问题,交叉熵损失不再适用,因为它假设每个节点仅属于一个类别。
针对这种情况,合适的选择是二元交叉熵 (BCE) 损失。设置上会有细微变化:
这里, 是 0 或 1(该类别的真实标签), 是来自 sigmoid 函数的预测概率。PyTorch 将其提供为 torch.nn.BCEWithLogitsLoss,它结合了 sigmoid 和 BCE 计算以提高稳定性。
虽然我们主要讨论了节点分类,但 GNN 还应用于需要不同损失函数的其他任务。
链路预测:该任务通常被设定为一个二分类问题:对于任意一对节点,它们之间是否存在边?你可以获取两个节点的最终嵌入 (embedding) ,通过点积等算子将它们结合,然后通过 sigmoid 函数预测存在链接的概率。接着使用二元交叉熵损失对照真实的图结构进行模型训练。
图分类:在该任务中,目标是为整个图分配一个标签。使用读出层或池化层将所有节点嵌入聚合成一个图级嵌入 。然后将此嵌入输入标准分类器。如果是多分类问题,你会像在节点分类中一样使用交叉熵损失。如果是回归任务(例如预测分子属性),你可能会使用回归损失,如均方误差 (MSE)。
选择哪种损失函数由任务输出的性质决定。其基本原理与深度学习 (deep learning)其他分支一致;主要区别在于预测结果源自 GNN 独特的结构。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•