训练基础模型本身就已经对计算资源提出了很高要求,而元学习固有的嵌套优化循环和复杂的梯度计算则进一步加剧了这些困难。依赖单个计算节点通常不可行。将元学习过程分布到多个设备(GPU甚至多台机器)上变得非常必要,以处理基础模型的规模以及MAML等算法的计算需求,尤其是在处理元批次中的大量任务时。
然而,分布元学习并不像标准分布式深度学习(如简单数据并行)那样直接。元学习固有的结构,包括其内循环调整和外循环元更新,带来了独特的通信模式和同步要求。下面我们来看一下并行元学习计算的主要策略。
任务并行
这可以说是分布元学习最自然的方式。其主要思想是在一个元批次中并行处理不同的任务。由于元批次中每个任务 Ti 的内循环更新通常在元梯度计算之前是独立的,因此我们可以将不同的任务(或任务子集)分配给不同的工作单元。
工作原理:
- 分发: 中央协调器(或rank 0)将当前元批次中的任务 T1,T2,...,TN 分发给 W 个可用工作单元。每个工作单元 w 接收一个任务子集 {Ti}i∈Tasksw。
- 内循环计算: 每个工作单元 w 独立地为其分配的任务执行内循环优化步骤。对于任务 Ti,这包括从当前元参数 θ 开始计算更新后的参数 ϕi:
ϕi=内循环更新(θ,D支持,i)
这可能涉及在支持集 D支持,i 上执行一个或多个梯度步骤。
- 结果聚合: 完成内循环后,工作单元需要通信外循环更新所需的结果。这通常包括返回在查询集 D查询,i 上计算的任务特定梯度 ∇θLTi(ϕi,D查询,i),或者根据具体的元学习算法(例如MAML与Reptile)可能返回更新后的参数 ϕi 本身。可以使用像
all_gather 或发送到参数服务器等通信策略。
- 元更新: 中央协调器聚合所有工作单元的结果,并计算最终的元梯度 ∇θL元。例如,在MAML中,这近似于:
∇θL元≈N1i=1∑N∇θLTi(ϕi,D查询,i)
然后,此聚合梯度用于更新共享的元参数 θ。
元学习中任务并行的示意图。元批次中的任务被分发给工作单元,每个工作单元执行内循环更新。结果被发送回协调器,用于元梯度计算和参数更新。
优点:
- 直接并行化计算最耗时的部分:运行多个独立的任务适应过程。
- 能很好地随着每个元批次中的任务数量增加而扩展。
缺点:
- 通信可能成为瓶颈,尤其是在为元更新聚合梯度或参数时,对于大量工作单元或大型模型而言更是如此。
- 在元更新之前需要仔细同步。
- 如果任务的计算成本差异很大(例如,支持集大小或内循环步数不同),则可能出现负载不均衡。
任务内数据并行
任务并行是分发任务,而标准数据并行可以应用于单个任务的内循环内部,尤其是在支持集 D支持,i 足够大,或者基础模型的正向/反向传播对于单个任务来说计算量也很大的情况下。
工作原理:
- 分配: 任务 Ti 被分配给一组工作单元(这组工作单元可以是混合设置中处理其他任务的同一组,也可以是专用组)。
- 数据分片: 该任务的支持集 D支持,i 被分配到组内的工作单元上。
- 并行梯度计算: 在任务 Ti 的内循环的每个步骤中,每个工作单元使用相同的当前参数(无论是 θ 还是中间参数 ϕi(k)),基于其数据分片计算梯度。
- 梯度平均: 该内步骤中在工作单元之间计算的梯度会被平均,通常使用高效的集体通信原语,如
all_reduce。
- 内部参数更新: 平均后的梯度用于更新任务特定参数 ϕi。此过程重复所需数量的内步骤。
这种方法本质上是将标准分布式数据并行(DDP)重复应用于元学习的内循环中。当单个任务适应过程(在支持集上的正向/反向传播)本身是主要计算成本时,这种方法最为有效。
优点:
- 使用成熟的DDP技术和库(例如PyTorch的DistributedDataParallel、Horovod)。
- 当单个任务涉及大量数据或每一步计算量很大时,效果良好。
缺点:
- 增加通信频率,因为通过这种方式处理的每个任务,其梯度需要在每个内步骤进行同步。
- 对于支持集较小的典型少样本场景,效果不佳。
- 与任务并行结合使用时,管理起来可能更复杂。
模型并行(流水线/张量并行)
对于单个GPU无法容纳的基础模型,模型并行不是可选项,而是必要条件。这涉及将模型本身拆分到多个设备上。
- 流水线并行: 将模型层按顺序划分到不同设备上。激活值从一个设备传递到下一个设备。需要通过微批处理来管理“流水线气泡”(空闲时间)。
- 张量并行: 将单个权重矩阵和激活值拆分到不同设备上。矩阵乘法等操作需要在操作本身内部进行同步通信(例如
all_reduce)。Megatron-LM等框架实现了这些方法的复杂版本。
在元学习背景下,模型并行与任务并行和数据并行正交运作。它定义了当模型跨越多个GPU时,单个正向/反向传播(无论是内循环还是外循环)的计算如何执行。
与元学习的相互作用:
- 模型并行必须在参与任务并行或数据并行的每个工作单元内部实现。
- 它会显著增加梯度传播的复杂性,特别是对于MAML等需要通过内循环优化过程进行反向传播的二阶方法。计算跨模型分区的Hessian-向量积或完整二阶导数会增加大量的通信和同步开销。
- 一阶近似(FOMAML,Reptile)或隐式梯度方法(iMAML)变得更受青睐,因为它们避免或简化了二阶反向传播,使得与模型并行的结合更易于管理。
优点:
缺点:
- 实现复杂度高。
- 在每次正向/反向传播中引入通信开销。
- 可能加剧元梯度计算的复杂性。
混合方法
在实际应用中,为基础模型扩展元学习通常需要结合这些策略:
- 模型 + 任务并行: 使用模型并行(流水线和/或张量)将大型基础模型容纳到由一组GPU组成的单个“工作单元”上。然后,使用任务并行将不同的元学习任务分发到多个此类多GPU工作单元。这是大型实验的常见设置。
- 模型 + 任务 + 数据并行: 如果单个任务也涉及大量计算(例如,更大的支持集,许多内步骤),数据并行可以在处理特定任务的每个多GPU工作单元内部叠加使用。
这些混合方法需要复杂的编排、高效的通信库(如用于GPU集体通信的NCCL)以及细致的资源管理。
通信策略和优化
无论采用何种分布策略,高效通信都非常重要。
- 集体操作:
all_reduce、broadcast、reduce_scatter 和 all_gather 等原语是NCCL、MPI等库提供的基本构建块,并集成到深度学习框架中。为梯度平均(数据并行)或结果聚合(任务并行)选择正确的集体操作会影响性能。
- 参数服务器与去中心化: 尽管存在传统的参数服务器,但由于带宽利用率更高,使用
all_reduce 的去中心化方法通常更受GPU集群青睐,特别是对于深度学习中常见的密集梯度更新。
- 梯度累积: 为了降低通信频率,尤其是在任务并行中的外循环更新,可以将在多个微批次甚至多个元批次中的梯度在每个工作单元上本地累积,然后再执行全局归约和参数更新。这牺牲了梯度新鲜度以换取降低通信开销。
- 异步更新: 允许工作单元异步计算和发送更新可能会通过避免同步等待来提高吞吐量。然而,这会带来梯度陈旧性问题,可能使敏感的元优化过程变得不稳定。尽管存在潜在的性能限制,同步更新在元学习中通常更常见且更可靠。
选择并实现正确的分布式策略是一个复杂的平衡过程。这在很大程度上取决于元学习算法(MAML、ProtoNets或Reptile)、基础模型的大小、每个元批次中的任务数量、支持/查询集的大小以及可用的硬件基础设施(网络带宽、GPU互连)。DeepSpeed和Megatron-LM等框架提供了可以帮助管理模型并行和数据并行某些方面的工具,但要针对任务并行元学习的特定通信模式有效适配它们,通常需要定制化的实现工作。