元训练循环的效率很大程度上取决于任务的选择和元批次的组建方式。考虑到基础模型的规模,即使处理单个任务的支持集和查询集也可能计算量很大,因此,简单或低效的抽样策略会迅速变得成本过高。优化任务抽样和批处理直接解决了之前提到的计算瓶颈,影响训练时间和资源消耗。
任务抽样的作用
任务抽样指的是从元训练分布 (p(T)) 中选择哪些任务 (Ti) 以用于下一次元更新的过程。目的不只是随机选择任务,而是以一种能促成高效学习的方式进行。
- 代表性: 抽取的任务应理想地反映模型在元测试时预期适应的任务的真实分布。有偏差的抽样可能导致泛化能力差。
- 效率: 抽样应快速,所选任务应有效帮助提升模型的元学习目标。抽样过于简单或冗余的任务会浪费计算资源。
- 稳定性: 抽样任务的顺序会影响元学习过程的稳定性。元批次之间任务难度或类型的高差异可能导致元梯度不稳定。
抽样策略
可以采用多种策略替代简单的均匀随机抽样:
- 均匀随机抽样: 这是最直接的方法,其中元训练集中的每个任务被选中用于任何给定元批次的概率相等。尽管实现简单,但它可能不是最有效的,特别是如果任务分布包含许多简单或无信息量的任务。
- 任务的课程学习: 类似于标准监督训练中的课程学习,任务可以按结构化顺序呈现给元学习器,通常基于难度。元训练可能从较简单的任务(例如,更少的样本、更少的类别、较不复杂的数据)开始,并逐步引入更具挑战性的任务。这有助于稳定元训练的初始阶段,特别是对于像 MAML 这样复杂的基于梯度的方法,防止早期出现可能导致不稳定的过大梯度。确定任务难度本身可能是一个难题,可能需要基于启发式方法或初步运行。
- 多样性导向抽样: 为确保元学习器泛化良好,抽样可以明确地旨在最大化元批次内或一系列批次中任务的多样性。多样性可以根据任务元数据(例如,领域、数据集来源)或任务数据之间的嵌入相似性来衡量。这鼓励模型学习适用于各种情况的适应策略。
- 基于难度的抽样 / 难点任务挖掘: 优先处理元学习器当前表现不佳的任务。这将计算精力集中在最需要的地方。识别难点任务通常涉及使用当前模型状态评估一组候选任务的表现,然后优先从表现低于特定阈值的任务中抽样。这会增加评估步骤的计算开销,但通过集中处理问题区域,有可能加快收敛速度。
构建元批次
任务抽样后,它们被分组到元批次中。一个元批次通常由 B 个任务组成,即 {T1,T2,...,TB}. 对于每个任务 Ti,元批次包含其对应的支持集 DiS 和查询集 DiQ。这个元批次的结构和大小对计算和内存有直接影响。
元批处理的考量:
元批次大小 (B)、估计梯度方差与每个元步骤计算成本之间的关系。增大 B 可以减少方差,但会增加资源需求。
- 任务异质性: 元批次可能包含不同大小(不同的 Nk, Nq)或复杂度的任务。这需要仔细实现以处理潜在的负载不平衡以及填充/掩码策略,尤其是在使用偏好统一计算结构的 GPU 等硬件加速器时。动态地将具有相似计算特征的任务进行批处理可以缓解部分这些问题。
- 异步执行: 在分布式设置中,甚至在具有多个工作器的单台机器上,任务抽样和数据加载可以与计算重叠。一个元批次可以正在处理,而下一个元批次的数据正在获取和准备。这需要更复杂的调度,但可以通过隐藏数据加载延迟来显著提高吞吐量。
大规模部署中的实现
对于基础模型,高效的任务抽样和批处理变得更为重要:
- 内存主导: 即使是一个任务适应步骤所需的激活和梯度内存也可能相当大。这常常迫使使用非常小的元批次大小 (B),或依赖于跨微批次的梯度累积等技术,即元梯度是增量计算的。
- 数据处理: 元学习通常假设任务是现成的。对于大型数据集,高效地抽取样本以形成成千上万个任务的支持集和查询集,需要优化数据加载管道。如果任务分布是静态的,预处理和存储任务数据结构会很有用。
- 与分布式训练的互动: 任务批处理方式与分布式训练策略密切相关。如果使用任务并行,每个工作器可能处理元批次中的一个或多个任务。工作器之间的负载均衡变得重要,可能偏好生成计算成本相似的任务的抽样策略。
总而言之,在将元学习扩展到基础模型时,选择合适的任务抽样策略和精心构建元批次并非次要的实现细节。它们是管理计算资源、确保训练过程稳定,并最终实现有效少样本适应能力的根本考量。最佳策略通常需要在理论益处(例如,大批次带来的较低梯度方差)与实际硬件限制之间取得平衡,并且可能需要根据具体的模型架构、数据集特点和所使用的元学习算法进行经验性调整。