将元学习,特别是基于梯度的方法,应用于基础模型带来了重大的计算障碍,主要源于元梯度的计算。尽管元学习的架构包含一个外循环,根据内循环适应后的性能来优化元参数 ($\theta$),但计算外循环梯度 ($\nabla_\theta \mathcal{L}_{\text{meta}}$) 的实际操作既复杂又资源密集,尤其是在大规模应用时。通过适应过程进行微分的开销主要困难在于元目标 $\mathcal{L}_{\text{meta}}$ 是使用参数 $\phi_i$ 进行评估的,而 $\phi_i$ 本身是应用于元参数 $\theta$ 的优化过程(内循环更新)的结果。以 MAML 这样典型的基于梯度的元学习配置为例。对于单个任务 $i$,经过一步梯度下降后,适应的参数 $\phi_i$ 如下:$$ \phi_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\text{train}, i}(\theta) $$这里,$\mathcal{L}{\text{train}, i}$ 是任务 $i$ 支持集上的损失,$\alpha$ 是内循环学习率。元目标通常是元批次中任务查询集上的平均损失,使用这些适应的参数进行评估:$\mathcal{L}{\text{meta}} = \sum_i \mathcal{L}_{\text{test}, i}(\phi_i)$。为了更新元参数 $\theta$,我们需要梯度 $\nabla_\theta \mathcal{L}_{\text{meta}}$。应用链式法则得到:$$ \nabla_\theta \mathcal{L}{\text{meta}} = \sum_i \nabla\theta \mathcal{L}{\text{test}, i}(\phi_i) = \sum_i \nabla{\phi_i} \mathcal{L}{\text{test}, i}(\phi_i) \cdot \nabla\theta \phi_i $$复杂度源于雅可比项 $\nabla_\theta \phi_i$:$$ \nabla_\theta \phi_i = \nabla_\theta (\theta - \alpha \nabla_\theta \mathcal{L}{\text{train}, i}(\theta)) = I - \alpha \nabla^2{\theta} \mathcal{L}_{\text{train}, i}(\theta) $$将其代回,显示出对内循环目标的海森矩阵 $\nabla^2_{\theta} \mathcal{L}_{\text{train}, i}(\theta)$ 的依赖:$$ \nabla_\theta \mathcal{L}{\text{meta}} = \sum_i (I - \alpha \nabla^2{\theta} \mathcal{L}{\text{train}, i}(\theta))^T \nabla{\phi_i} \mathcal{L}_{\text{test}, i}(\phi_i) $$这种对二阶导数的明确依赖是 MAML 等算法计算开销的主要来源。二阶导数与内存消耗计算和存储一个拥有数十亿参数 ($d$) 的基础模型的完整海森矩阵在计算上是不可行的,需要 $O(d^2)$ 的内存和计算量。MAML 的实际实现会避免形成完整的海森矩阵。相反,它们计算海森向量积 (HVP):$(\nabla^2_{\theta} \mathcal{L}{\text{train}, i}(\theta)) \cdot v$,这里的向量 $v$ 为 $\nabla{\phi_i} \mathcal{L}_{\text{test}, i}(\phi_i)$。这个 HVP 可以高效地计算,无需实例化完整的海森矩阵,通常使用与自动微分相关的技术(例如,第二次反向传播或结合正向和反向传播)。然而,即使计算 HVP 也会带来显著的开销:计算量增加: 与标准的 first-order 梯度计算相比,它需要模型进行额外的正向和/或反向传播。单个元梯度步骤涉及多个内循环步骤(每个步骤都是一次正向/反向传播),然后是元梯度计算,元梯度计算本身又涉及评估 $\nabla_{\phi_i} \mathcal{L}_{\text{test}, i}$(又一次反向传播),以及随后的 HVP 计算。这会使每次外循环更新的计算负担大幅增加。内存占用: 最主要的瓶颈通常是内存。为了计算元梯度,特别是二阶项,自动微分框架需要维护整个内循环优化过程的计算图。这包括正向传播中的激活以及内循环更新中可能使用的中间梯度。如果内循环涉及多个梯度步骤,此图的大小会随着步骤数线性增长。对于像 Transformer 这样的深度网络,每层的激活会消耗相当大的内存。存储应用于数十亿参数模型的多个适应步骤的图,会迅速耗尽标准加速器(GPU/TPU)上的可用内存。下图说明了单步 MAML 元梯度计算中的数据流和依赖关系,强调了需要保留计算图的位置。digraph MAML_Computation { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", color="#495057", fontcolor="#495057"]; edge [fontname="sans-serif", color="#868e96", fontcolor="#495057"]; theta [label="元参数 θ", color="#1c7ed6", fontcolor="#1c7ed6"]; support_data [label="支持数据(任务 i)", shape=cylinder, style=filled, fillcolor="#dee2e6"]; query_data [label="查询数据(任务 i)", shape=cylinder, style=filled, fillcolor="#dee2e6"]; subgraph cluster_inner { label = "内循环(任务 i)"; style=filled; color="#e9ecef"; node [shape=ellipse]; edge []; inner_fwd [label="正向传播\nL_train,i(θ)"]; inner_bwd [label="反向传播\n∇θ L_train,i(θ)"]; phi_i [label="适应参数\nϕi = θ - α ∇θ L_train,i"]; theta -> inner_fwd; support_data -> inner_fwd; inner_fwd -> inner_bwd; theta -> inner_bwd; // 梯度计算的隐含依赖 inner_bwd -> phi_i; theta -> phi_i; } subgraph cluster_outer { label = "外循环梯度"; style=filled; color="#e9ecef"; node [shape=ellipse]; edge []; outer_fwd [label="正向传播\nL_test,i(ϕi)"]; outer_bwd1 [label="反向传播\n∇ϕi L_test,i(ϕi)"]; hvp [label="海森向量积\n(∇²θ L_train,i) ⋅ v"]; meta_grad_i [label="元梯度(任务 i)\n∇θ L_meta,i"]; } meta_grad_final [label="最终元梯度\nΣ ∇θ L_meta,i", shape=box, style=rounded, color="#1c7ed6", fontcolor="#1c7ed6"]; phi_i -> outer_fwd [label="使用适应参数"]; query_data -> outer_fwd; outer_fwd -> outer_bwd1; phi_i -> outer_bwd1; // 梯度计算的隐含依赖 // HVP 的连接 outer_bwd1 -> hvp [label="v = ∇ϕi L_test,i", style=dashed]; // 对内循环计算图的海森依赖 inner_fwd -> hvp [label="需要内循环图", style=dashed, constraint=false]; inner_bwd -> hvp [label="需要内循环图", style=dashed, constraint=false]; theta -> hvp [label="关于 θ", style=dashed]; // 组合元梯度项 outer_bwd1 -> meta_grad_i; hvp -> meta_grad_i; meta_grad_i -> meta_grad_final; // 突出显示内存依赖 {rank=same; inner_fwd; inner_bwd;} inner_fwd -> inner_bwd [style=invis]; // 视觉上保持对齐 memory_dep [label="外循环反向传播所需的\n内存中存储的图", shape=note, style=filled, fillcolor="#ffec99", fontcolor="#495057"]; memory_dep -> inner_fwd [style=dotted, color="#f59f00", arrowhead=none]; memory_dep -> inner_bwd [style=dotted, color="#f59f00", arrowhead=none]; memory_dep -> phi_i [style=dotted, color="#f59f00", arrowhead=none]; }用于计算单个任务 MAML 元梯度的计算图依赖关系。内循环计算(正向传播、反向传播、参数更新)必须保留在内存中(橙色虚线),以计算二阶外循环梯度所需的海森向量积 (HVP)。一阶近似:一种权衡认识到这些困难,开发了 MAML 的一阶近似方法,例如 FOMAML 和 Reptile。FOMAML 明确忽略了二阶项,将元梯度近似为:$$ \nabla_\theta \mathcal{L}{\text{meta}} \approx \sum_i \nabla{\phi_i} \mathcal{L}_{\text{test}, i}(\phi_i) $$这种计算只需相对于适应参数 $\phi_i$ 计算测试损失的标准梯度,将 $\phi_i$ 视为一个独立参数而非 $\theta$ 的函数。这完全避免了 HVP 计算,并大幅减少了内存需求,因为内循环的计算图不需要为外循环的反向传播保留。Reptile 通过一种与有限差分相关的不同更新规则达到了类似的效果。尽管计算成本低得多,但一阶方法代表了一种近似。与二阶方法相比,它们可能表现出不同的收敛动态,并可能导致不同的解,尽管在许多实际情况下,特别是在大型模型和适当调整下,它们的表现非常出色。总而言之,元梯度的计算,特别是 MAML 中涉及二阶导数的那些计算,在处理基础模型时带来了主要的计算和内存瓶颈。通过内循环优化过程进行微分的需求,要么需要昂贵的海森向量积,要么以牺牲近似精度为代价打破计算图依赖。理解这些困难是开发策略的第一步,这些策略将在后续讨论,以使元学习在大规模应用中变得可行。