趋近智
GNN 架构是强有力的特征提取器,能够将每个节点的结构和属性信息转化为稠密的向量嵌入。这些嵌入(通常表示为矩阵 )包含的高层表示,在下游任务中比原始输入特征更有用。然而,GNN 本身并不直接输出类别标签。为了执行节点分类等任务,必须在 GNN 模型中添加最后一个组件,将这些学习到的嵌入映射到类别预测上。
这最后一个组件通常被称为分类头 (classification head)。在大多数常见的 GNN 应用中,这只是一个以节点嵌入作为输入的标准前馈神经网络。在最简单且最常见的形式中,分类头是一个单一的线性层,不包含任何额外的隐藏层或非线性激活函数。
这个线性层的作用是充当一个可训练的分类器。它接收维度为 的嵌入(其中 是 GNN 编码器的输出维度),并将其投影到大小为 的向量中(其中 是数据集中的类别总数)。该输出向量中的每个元素代表特定类别的原始、未归一化的分数。这些分数通常被称为 logits。
节点分类的完整模型流程可以看作一个两阶段的过程:
这种结构使得模型在训练期间能够同时学习图表示和分类任务。
节点分类的端到端架构。GNN 编码器生成嵌入,然后将其传递给简单的线性分类器以产生最终的类别 logits。
让我们使用 PyTorch 将此架构转换为 Python 类。假设我们要构建一个用于分类的两层图卷积网络 (GCN)。该模型类将包含用于编码的 GCN 层和用于分类的标准 torch.nn.Linear 层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCNNodeClassifier(nn.Module):
"""用于节点分类的两层 GCN 模型。"""
def __init__(self, in_channels, hidden_channels, num_classes):
super(GCNNodeClassifier, self).__init__()
# GNN 编码器层
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
# 分类头
self.classifier = nn.Linear(hidden_channels, num_classes)
def forward(self, x, edge_index):
# 1. GNN 编码器:获取节点嵌入
# 第一层 GCN
h = self.conv1(x, edge_index)
h = F.relu(h)
# 第二层 GCN
h = self.conv2(h, edge_index)
# 最终的节点嵌入现在存储在 'h' 中
# 2. 分类头:生成 logits
output = self.classifier(h)
return output
在此实现中:
in_channels:输入节点特征的维度(例如,Cora 数据集为 1433)。hidden_channels:GNN 层产生的节点嵌入维度。这是一个可以调整的超参数。num_classes:数据集中不同节点标签的数量(例如,Cora 数据集为 7)。forward(self, x, edge_index):此方法定义了计算流程。输入节点特征 x 和图结构 edge_index 通过两个 GCN 层,中间应用了 ReLU 激活函数。得到的嵌入 h 随后传递给 self.classifier 层以获得最终的 logits。许多节点分类任务的一个显著特点是它们在半监督(或更准确地说,转导式 (transductive))环境下运行。这意味着虽然我们只有一小部分节点的标签(训练集),但 GNN 编码器会使用整个图结构(包括所有节点和边)来生成嵌入。
在转导式设置中,GNN 模型在训练期间可以访问图中所有节点的特征和连接关系,甚至是验证集和测试集中的节点。模型的任务是为这个已知图中的无标签节点预测标签。
我们上面定义的 GCNNodeClassifier 就是为此设计的。它的 forward 方法为图中的每一个节点计算嵌入和 logits。在下一节讨论损失函数时,我们将看到如何使用掩码来确保模型的误差仅根据有标签训练节点的预测结果来计算。
有了这个端到端的模型结构,接下来的任务是定义一个目标函数来衡量其表现并指导其学习过程。这就引出了损失函数的话题。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造