趋近智
本节提供实现原型网络的实用指导,特别是采用大型基础模型学习到的有力表示。在基于度量的元学习理论探讨之上,我们将着重使用基础模型的预计算嵌入作为输入特征,用于在少样本设定中构建原型并分类查询实例。这种方法避免了整个网络的端到端元训练,使其计算高效,且常有出人意料的良好效果。
继续之前,请确保您已准备好以下组成部分:
numpy、torch 或 tensorflow)的Python环境。我们将使用类似PyTorch的语法进行说明。主要思路是将支持样本和查询样本都映射到基础模型提供的嵌入空间中,然后在该空间中执行最近原型分类。
嵌入提取: 通过(冻结的)基础模型 fϕ(⋅) 处理支持集样本 xi∈S 和查询集样本 xj∗∈Q,以获取它们各自的嵌入:
由于基础模型较大,批处理并可能缓存这些嵌入对于效率很重要,特别是在元测试期间或相同样本出现在多个任务中时。
import torch
@torch.no_grad() # 确保不计算梯度
def get_embeddings(foundation_model, data_loader, device):
foundation_model.eval() # 将模型设置为评估模式
all_embeddings = []
all_labels = []
for inputs, labels in data_loader:
inputs = inputs.to(device)
# 假设基础模型直接输出嵌入或在投影层之后输出
embeddings = foundation_model(inputs)
all_embeddings.append(embeddings.cpu())
all_labels.append(labels.cpu())
return torch.cat(all_embeddings), torch.cat(all_labels)
# 在任务循环中的使用示例:
# support_loader = # 支持集的数据加载器 (N*K 样本)
# query_loader = # 查询集的数据加载器 (N*Q' 样本)
# support_embeddings, support_labels = get_embeddings(model, support_loader, device)
# query_embeddings, query_labels = get_embeddings(model, query_loader, device)
```
2. 原型计算: 对于支持集中存在的每个类别 k,通过对其 K 个支持样本的嵌入求平均来计算其原型 ck: ck=∣Sk∣1∑(xi,yi)∈Skfϕ(xi)=K1∑i:yi=kei 这会得到 N 个原型向量,每个向量代表嵌入空间中一个类别的中心点。
```python
def calculate_prototypes(support_embeddings, support_labels, classes):
# classes: 包含任务唯一类别标签的列表或张量
num_classes = len(classes)
embedding_dim = support_embeddings.size(1)
prototypes = torch.zeros(num_classes, embedding_dim, device=support_embeddings.device)
for i, k in enumerate(classes):
# 选择属于类别 k 的嵌入
class_mask = (support_labels == k)
class_embeddings = support_embeddings[class_mask]
# 计算平均嵌入
prototypes[i] = class_embeddings.mean(dim=0)
return prototypes
```
3. 查询分类: 通过根据所选距离度量 d(⋅,⋅) 找到最近的原型 ck 来分类每个查询嵌入 ej∗。查询样本 xj∗ 的预测类别 y^j∗ 为: y^j∗=k∈{1,...,N}argmind(ej∗,ck) d 的常见选择包括平方欧几里得距离或余弦距离。平方欧几里得距离因其计算简便性(避免开方)以及在 argmin 操作中与欧几里得距离的等效性而常被选用。
```python
def classify_queries(query_embeddings, prototypes):
# 使用平方欧几里得距离: ||a - b||^2
# `torch.cdist` 计算成对距离。
# query_embeddings 形状: [查询数量, 嵌入维度]
# prototypes 形状: [类别数量, 嵌入维度]
# dists 形状: [查询数量, 类别数量]
dists = torch.cdist(query_embeddings, prototypes.unsqueeze(0)).squeeze(0) ** 2 # 为提高效率,使用广播
# 找到每个查询距离最小的索引(类别)
# predictions 形状: [查询数量]
predictions = torch.argmin(dists, dim=1)
return predictions # 这些是相对于 'prototypes' 张量的索引
```
如有必要,请记住将预测索引映射回原始类别标签。
以下图表说明了单一少样本任务的核心工作流程:
流程图说明了使用预计算的基础模型嵌入的单一N-way K-shot任务中原型网络的处理过程。支持和查询样本被嵌入,原型从支持嵌入中计算,查询嵌入根据与原型的接近程度进行分类。
通过计算在多个保留任务(元测试)的查询集上所做预测的准确性(或其他相关度量)来评估性能。
这个实用实现展现了将强效预训练表示与原型网络简洁而有效的方法结合起来的效率。虽然我们在此使用了固定嵌入,但更先进的方法可能涉及元学习嵌入函数本身,或在元训练阶段微调基础模型的部分内容,这些主题将在本课程的其他部分进行探讨。比较使用不同距离度量或归一化策略获得的结果,可以提供对基础模型嵌入空间几何形状的更多理解,用于少样本学习。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造