本节提供实现原型网络的实用指导,特别是采用大型基础模型学习到的有力表示。在基于度量的元学习理论探讨之上,我们将着重使用基础模型的预计算嵌入作为输入特征,用于在少样本设定中构建原型并分类查询实例。这种方法避免了整个网络的端到端元训练,使其计算高效,且常有出人意料的良好效果。先决条件与配置继续之前,请确保您已准备好以下组成部分:基础模型嵌入: 可使用能生成有意义嵌入的预训练基础模型,适配您的数据模态(例如,图像用ViT,文本用BERT)。您可以加载模型本身,或者更高效地预先计算并存储嵌入用于您的少样本数据集。本练习中,我们假定嵌入是固定的(即基础模型的权重已冻结)。少样本数据集: 构成N-way K-shot分类任务的数据集。每个任务应包含:支持集 $S = {(x_1, y_1), ..., (x_{N \times K}, y_{N \times K})}$,包含N个类别中每个类别的K个带标签样本。查询集 $Q = {(x^_1, y^1), ..., (x^*{N \times Q'}, y^*_{N \times Q'}) }$,包含N个类别中每个类别的新增Q'个样本,用于评估。环境: 一个配置有标准科学计算库(numpy、torch 或 tensorflow)的Python环境。我们将使用类似PyTorch的语法进行说明。主要实现步骤主要思路是将支持样本和查询样本都映射到基础模型提供的嵌入空间中,然后在该空间中执行最近原型分类。嵌入提取: 通过(冻结的)基础模型 $f_\phi(\cdot)$ 处理支持集样本 $x_i \in S$ 和查询集样本 $x^*_j \in Q$,以获取它们各自的嵌入:支持嵌入:对于所有 $(x_i, y_i) \in S$,有 $e_i = f_\phi(x_i)$。查询嵌入:对于所有 $(x^_j, y^_j) \in Q$,有 $e^j = f\phi(x^_j)$。由于基础模型较大,批处理并可能缓存这些嵌入对于效率很重要,特别是在元测试期间或相同样本出现在多个任务中时。获取嵌入的函数import torch @torch.no_grad() # 确保不计算梯度 def get_embeddings(foundation_model, data_loader, device): foundation_model.eval() # 将模型设置为评估模式 all_embeddings = [] all_labels = [] for inputs, labels in data_loader: inputs = inputs.to(device) # 假设基础模型直接输出嵌入或在投影层之后输出 embeddings = foundation_model(inputs) all_embeddings.append(embeddings.cpu()) all_labels.append(labels.cpu()) return torch.cat(all_embeddings), torch.cat(all_labels) # 在任务循环中的使用示例: # support_loader = # 支持集的数据加载器 (N*K 样本) # query_loader = # 查询集的数据加载器 (N*Q' 样本) # support_embeddings, support_labels = get_embeddings(model, support_loader, device) # query_embeddings, query_labels = get_embeddings(model, query_loader, device) ```2. 原型计算: 对于支持集中存在的每个类别 $k$,通过对其 $K$ 个支持样本的嵌入求平均来计算其原型 $c_k$: $$ c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) = \frac{1}{K} \sum_{i: y_i=k} e_i $$ 这会得到 $N$ 个原型向量,每个向量代表嵌入空间中一个类别的中心点。```python计算原型def calculate_prototypes(support_embeddings, support_labels, classes): # classes: 包含任务唯一类别标签的列表或张量 num_classes = len(classes) embedding_dim = support_embeddings.size(1) prototypes = torch.zeros(num_classes, embedding_dim, device=support_embeddings.device) for i, k in enumerate(classes): # 选择属于类别 k 的嵌入 class_mask = (support_labels == k) class_embeddings = support_embeddings[class_mask] # 计算平均嵌入 prototypes[i] = class_embeddings.mean(dim=0) return prototypes ```3. 查询分类: 通过根据所选距离度量 $d(\cdot, \cdot)$ 找到最近的原型 $c_k$ 来分类每个查询嵌入 $e^_j$。查询样本 $x^_j$ 的预测类别 $\hat{y}^_j$ 为: $$ \hat{y}^_j = \underset{k \in {1, ..., N}}{\operatorname{argmin}} ; d(e^*_j, c_k) $$ $d$ 的常见选择包括平方欧几里得距离或余弦距离。平方欧几里得距离因其计算简便性(避免开方)以及在 argmin 操作中与欧几里得距离的等效性而常被选用。```python分类查询点def classify_queries(query_embeddings, prototypes): # 使用平方欧几里得距离: ||a - b||^2 # `torch.cdist` 计算成对距离。 # query_embeddings 形状: [查询数量, 嵌入维度] # prototypes 形状: [类别数量, 嵌入维度] # dists 形状: [查询数量, 类别数量] dists = torch.cdist(query_embeddings, prototypes.unsqueeze(0)).squeeze(0) ** 2 # 为提高效率,使用广播 # 找到每个查询距离最小的索引(类别) # predictions 形状: [查询数量] predictions = torch.argmin(dists, dim=1) return predictions # 这些是相对于 'prototypes' 张量的索引 ``` 如有必要,请记住将预测索引映射回原始类别标签。基础模型嵌入的考量维度: 基础模型嵌入通常是高维的(例如,768、1024 或更多)。虽然原型网络能自然地处理这一点,但请注意距离计算的计算成本。归一化: 在计算原型和距离之前,对支持嵌入和查询嵌入进行L2范数归一化通常是有益的。这使得分类依赖于角度(余弦相似度)而非幅度,这在高维空间中可能具有优势,因为幅度可能差异很大。如果使用余弦相似度,归一化是固有的。对于欧几里得距离,归一化会将其转换为余弦相似度的函数:对于归一化向量 $u, v$,有 $||u-v||^2_2 = 2 - 2 u \cdot v$。度量选择: 虽然欧几里得距离是标准做法,但如果类别簇之间的角度分离比它们在嵌入空间中的位置更能提供信息,那么余弦距离(或相似度)可能更合适,尤其是在归一化之后。实验很重要。过程可视化以下图表说明了单一少样本任务的核心工作流程:digraph G { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", fontsize=10]; edge [fontname="sans-serif", fontsize=9]; subgraph cluster_support { label = "支持集"; style=dashed; color="#adb5bd"; s1 [label="支持样本 1 (类别 A)"]; s2 [label="支持样本 2 (类别 A)"]; s3 [label="支持样本 K (类别 A)"]; s4 [label="支持样本 1 (类别 B)"]; s5 [label="支持样本 K (类别 B)"]; sn [label="... (N 个类别)"]; } subgraph cluster_query { label = "查询集"; style=dashed; color="#adb5bd"; q1 [label="查询样本 1"]; q2 [label="查询样本 M"]; } subgraph cluster_embed { label = "嵌入空间 (f_phi)"; style=filled; fillcolor="#e9ecef"; color="#adb5bd"; node [shape=ellipse, style=filled, fillcolor="#ffffff"]; se1 [label="e_1"]; se2 [label="e_2"]; se3 [label="e_K (A)"]; se4 [label="e_1 (B)"]; se5 [label="e_K (B)"]; sen [label="..."]; qe1 [label="e*_1"]; qe2 [label="e*_M"]; } subgraph cluster_proto { label = "原型计算"; style=filled; fillcolor="#dee2e6"; color="#adb5bd"; node [shape=diamond, style=filled, fillcolor="#a5d8ff"]; pa [label="原型 A\n(A 类嵌入的平均值)"]; pb [label="原型 B\n(B 类嵌入的平均值)"]; pn [label="..."]; } subgraph cluster_classify { label = "分类"; style=filled; fillcolor="#e9ecef"; color="#adb5bd"; node [shape=invhouse, style=filled, fillcolor="#ffec99"]; class1 [label="分类 e*_1\n(到原型的最小距离)"]; class2 [label="分类 e*_M\n(到原型的最小距离)"]; } # Edges s1 -> se1; s2 -> se2; s3 -> se3; s4 -> se4; s5 -> se5; sn -> sen; {s1, s2, s3, s4, s5, sn} -> foundation_model [label="应用 f_phi", style=invis]; # Grouping edge q1 -> qe1; q2 -> qe2; {q1, q2} -> foundation_model_q [label="应用 f_phi", style=invis]; # Grouping edge {se1, se2, se3} -> pa [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"]; {se4, se5} -> pb [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"]; sen -> pn [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"]; qe1 -> class1 [label="距离(e*_1, 原型 A)\n距离(e*_1, 原型 B)\n...", color="#1c7ed6"]; qe2 -> class2 [label="距离(e*_M, 原型 A)\n距离(e*_M, 原型 B)\n...", color="#1c7ed6"]; pa -> class1 [style=dotted, color="#1c7ed6"]; pb -> class1 [style=dotted, color="#1c7ed6"]; pn -> class1 [style=dotted, color="#1c7ed6"]; pa -> class2 [style=dotted, color="#1c7ed6"]; pb -> class2 [style=dotted, color="#1c7ed6"]; pn -> class2 [style=dotted, color="#1c7ed6"]; # Invisible nodes for layout help foundation_model [style=invis]; foundation_model_q [style=invis]; {rank=same; s1; q1;} {rank=same; foundation_model; foundation_model_q;} {rank=same; se1; qe1;} {rank=same; pa;} {rank=same; class1;} }流程图说明了使用预计算的基础模型嵌入的单一N-way K-shot任务中原型网络的处理过程。支持和查询样本被嵌入,原型从支持嵌入中计算,查询嵌入根据与原型的接近程度进行分类。评估与后续步骤通过计算在多个保留任务(元测试)的查询集上所做预测的准确性(或其他相关度量)来评估性能。这个实用实现展现了将强效预训练表示与原型网络简洁而有效的方法结合起来的效率。虽然我们在此使用了固定嵌入,但更先进的方法可能涉及元学习嵌入函数本身,或在元训练阶段微调基础模型的部分内容,这些主题将在本课程的其他部分进行探讨。比较使用不同距离度量或归一化策略获得的结果,可以提供对基础模型嵌入空间几何形状的更多理解,用于少样本学习。