趋近智
将元学习,特别是基于梯度的方法,应用于基础模型带来了重大的计算障碍,主要源于元梯度的计算。尽管元学习的架构包含一个外循环,根据内循环适应后的性能来优化元参数 (θ),但计算外循环梯度 (∇θLmeta) 的实际操作既复杂又资源密集,尤其是在大规模应用时。
主要困难在于元目标 Lmeta 是使用参数 ϕi 进行评估的,而 ϕi 本身是应用于元参数 θ 的优化过程(内循环更新)的结果。以 MAML 这样典型的基于梯度的元学习配置为例。对于单个任务 i,经过一步梯度下降后,适应的参数 ϕi 如下:
ϕi=θ−α∇θLtrain,i(θ)这里,Ltrain,i 是任务 i 支持集上的损失,α 是内循环学习率。元目标通常是元批次中任务查询集上的平均损失,使用这些适应的参数进行评估:Lmeta=∑iLtest,i(ϕi)。
为了更新元参数 θ,我们需要梯度 ∇θLmeta。应用链式法则得到:
∇θLmeta=i∑∇θLtest,i(ϕi)=i∑∇ϕiLtest,i(ϕi)⋅∇θϕi复杂度源于雅可比项 ∇θϕi:
∇θϕi=∇θ(θ−α∇θLtrain,i(θ))=I−α∇θ2Ltrain,i(θ)将其代回,显示出对内循环目标的海森矩阵 ∇θ2Ltrain,i(θ) 的依赖:
∇θLmeta=i∑(I−α∇θ2Ltrain,i(θ))T∇ϕiLtest,i(ϕi)这种对二阶导数的明确依赖是 MAML 等算法计算开销的主要来源。
计算和存储一个拥有数十亿参数 (d) 的基础模型的完整海森矩阵在计算上是不可行的,需要 O(d2) 的内存和计算量。MAML 的实际实现会避免形成完整的海森矩阵。相反,它们计算海森向量积 (HVP):(∇θ2Ltrain,i(θ))⋅v,这里的向量 v 为 ∇ϕiLtest,i(ϕi)。
这个 HVP 可以高效地计算,无需实例化完整的海森矩阵,通常使用与自动微分相关的技术(例如,第二次反向传播或结合正向和反向传播)。然而,即使计算 HVP 也会带来显著的开销:
计算量增加: 与标准的 first-order 梯度计算相比,它需要模型进行额外的正向和/或反向传播。单个元梯度步骤涉及多个内循环步骤(每个步骤都是一次正向/反向传播),然后是元梯度计算,元梯度计算本身又涉及评估 ∇ϕiLtest,i(又一次反向传播),以及随后的 HVP 计算。这会使每次外循环更新的计算负担大幅增加。
内存占用: 最主要的瓶颈通常是内存。为了计算元梯度,特别是二阶项,自动微分框架需要维护整个内循环优化过程的计算图。这包括正向传播中的激活以及内循环更新中可能使用的中间梯度。如果内循环涉及多个梯度步骤,此图的大小会随着步骤数线性增长。对于像 Transformer 这样的深度网络,每层的激活会消耗相当大的内存。存储应用于数十亿参数模型的多个适应步骤的图,会迅速耗尽标准加速器(GPU/TPU)上的可用内存。
下图说明了单步 MAML 元梯度计算中的数据流和依赖关系,强调了需要保留计算图的位置。
用于计算单个任务 MAML 元梯度的计算图依赖关系。内循环计算(正向传播、反向传播、参数更新)必须保留在内存中(橙色虚线),以计算二阶外循环梯度所需的海森向量积 (HVP)。
认识到这些困难,开发了 MAML 的一阶近似方法,例如 FOMAML 和 Reptile。FOMAML 明确忽略了二阶项,将元梯度近似为:
∇θLmeta≈i∑∇ϕiLtest,i(ϕi)这种计算只需相对于适应参数 ϕi 计算测试损失的标准梯度,将 ϕi 视为一个独立参数而非 θ 的函数。这完全避免了 HVP 计算,并大幅减少了内存需求,因为内循环的计算图不需要为外循环的反向传播保留。Reptile 通过一种与有限差分相关的不同更新规则达到了类似的效果。
尽管计算成本低得多,但一阶方法代表了一种近似。与二阶方法相比,它们可能表现出不同的收敛动态,并可能导致不同的解,尽管在许多实际情况下,特别是在大型模型和适当调整下,它们的表现非常出色。
总而言之,元梯度的计算,特别是 MAML 中涉及二阶导数的那些计算,在处理基础模型时带来了主要的计算和内存瓶颈。通过内循环优化过程进行微分的需求,要么需要昂贵的海森向量积,要么以牺牲近似精度为代价打破计算图依赖。理解这些困难是开发策略的第一步,这些策略将在后续讨论,以使元学习在大规模应用中变得可行。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造