关系网络 (RNs) 为少样本学习 (few-shot learning)引入了一种独特的方法。与计算类别原型并在嵌入 (embedding)空间中使用固定距离度量的原型网络不同,关系网络由 Sung 等人 (2018) 提出,直接学习一个深度非线性度量函数。此函数旨在确定查询样本和支持样本之间的相似性或“关系”。关系网络不假设欧氏距离这样的简单度量就足够,它们认为专用神经网络 (neural network)能更好地捕捉少样本分类所需的复杂关系,尤其是在处理复杂视觉或语义特征时。
关系网络结构
核心思想是训练一个网络,它能明确输出特征表示对之间的标量相似度分数。一个典型的关系网络包含两个主要部分:
-
嵌入 (embedding)模块 (fϕ): 这个模块,通常是用于视觉任务的卷积神经网络 (neural network) (CNN) 或用于序列数据的Transformer编码器,将原始输入(包括支持集样本 xs 和查询样本 xq)映射为特征嵌入。
嵌入s=fϕ(xs)
嵌入q=fϕ(xq)
在调整基础模型时,fϕ 可能是预训练 (pre-training)模型本身(可能冻结)或其一部分。
-
关系模块 (gψ): 这个模块接收嵌入对(通常是查询嵌入与支持嵌入结合),并输出一个介于 0 和 1 之间的标量关系分数,表示与分类任务相关的相似程度。对于一个 N 类、K 样本任务,一个常见做法是首先聚合每个类别 ci 的支持嵌入(例如,通过求和或求平均),以得到一个类别表示 si:
si=K1k=1∑Kfϕ(xs(i,k))
然后,关系模块处理查询嵌入 q=fϕ(xq) 和每个类别表示 si 的连接:
ri=gψ(组合(si,q))
组合 函数通常是简单的拼接。关系模块 gψ 本身通常是一个较小的神经网络,比如几个卷积层后面跟着全连接层,或者只是一个多层感知机 (MLP)。它被设计用于学习一个任务特定的相似度函数。
这是关系网络中一个类别 ci 的信息流。支持图像和查询图像被嵌入,支持嵌入被聚合,与查询嵌入组合后,输入到关系模块以产生相似度分数。此过程对支持集中的所有类别重复进行。
关系网络训练
关系网络按集训练,与其他元学习方法类似。在每一集中,会抽样一个任务(例如,一个 C 类、K 样本分类问题)。网络处理该任务的支持集和查询样本。目标函数通常旨在将关系分数 ri 推向 1(如果查询样本 xq 属于类别 ci),否则推向 0。常见的损失函数 (loss function)是均方误差 (MSE):
L=task T∑(xq,yq)∈Queries(T)∑i=1∑C(ri−1yq=ci)2
这里,ri 是给定查询 xq 对类别 ci 的预测关系分数,而 1yq=ci 是一个指示函数,如果真实标签 yq 与类别 ci 匹配则为 1,否则为 0。此损失通过关系模块 gψ 和嵌入 (embedding)模块 fϕ 反向传播 (backpropagation),使得特征表示和相似度函数能同时学习。
与原型网络的比较
关系网络与原型网络显著不同:
- 学习的度量: 关系网络明确学习相似度函数 (gψ),而原型网络在学习的嵌入 (embedding)空间中使用预定义度量(例如,平方欧氏距离)。
- 复杂性: 关系模块 gψ 能够建模高度非线性关系,可能比固定距离度量提供更多灵活性,尤其当嵌入空间中的底层类别结构复杂时。
- 计算: 关系网络需要将嵌入对(或查询-聚合对)通过关系模块逐类处理,这可能比原型网络中简单的距离计算在计算上更密集,尤其随着类别数 (C) 或样本数 (K) 增加时。
使用基础模型嵌入 (embedding)
在大型基础模型的背景下应用关系网络时,嵌入模块 fϕ 自然地被基础模型取代。
- 冻结嵌入: 一种策略是将基础模型用作固定的特征提取器。嵌入 fϕ(x) 被预计算或即时生成,并且在元学习期间仅训练关系模块 gψ。这种方法计算效率高,但完全依赖于基础模型默认嵌入空间的质量和适用性。
- 微调 (fine-tuning): 另一种方式是,在元训练期间,基础模型 fϕ 的一部分可以与 gψ 一起微调。这提供了更多灵活性,但显著增加计算复杂性,并需要仔细考虑稳定性,尤其考虑到元学习优化的双层性质。参数 (parameter)高效微调 (PEFT) 方法可能可以整合于此。
基础模型嵌入的高维度(例如 768、1024 或更多维度)在设计关系模块 gψ 时需要考虑。一个简单的 MLP 可能面临维度灾难或变得计算负担重。降维技术(如果在不损失重要信息的情况下可行)或精心构建的关系模块(例如,使用注意力机制 (attention mechanism)或因子分解)可能变得必要。
优点与局限
优点:
- 学习根据任务分布调整的、灵活的、可能非线性的相似度度量。
- 能够捕捉固定距离度量可能遗漏的复杂关系。
- 在引入时,在多个少样本基准测试上达到先进性能。
局限:
- 计算成本高于原型网络等方法,因为它每个查询-类别比较都需要通过关系模块进行前向传播。
- 性能对嵌入 (embedding)模块和关系模块的结构设计敏感。
- 训练有时不如简单的基于度量的方法稳定,需要仔细的超参数 (parameter) (hyperparameter)调整。
关系网络在基于度量的元学习中提供一个有力的替代方案,将复杂性从单纯学习良好嵌入,转移到联合学习嵌入和灵活的比较机制。它们在基础模型上的有效性取决于有效使用预训练 (pre-training)表示,同时管理学习到的关系模块的计算成本。