尽管模型无关元学习 (MAML) 提供了一个有用的框架,可学习适应性初始化,但它对二阶导数(或通过整个内循环优化路径进行反向传播)的依赖带来了很大的计算和内存难题,特别是对于大型基础模型。计算甚至近似海森矩阵,或者存储大量内梯度步骤的计算图,很快就会变得难以承受。
隐式MAML (iMAML) 提供了一种替代方法,它巧妙地避免了这些困难,借助了隐式微分的能力。iMAML不是通过内循环优化器的步骤进行微分,而是通过内循环旨在满足的最优条件进行微分。
核心思想:微分最优条件
回顾一下,MAML 中内循环的目标是从元参数 θ 开始,找到特定于任务的参数 θi′。对于具有支持集损失 Ltaski 的任务 i,这通常通过梯度下降来完成:
θi,k+1′=θi,k′−α∇θ′Ltaski(θi,k′)
其中 θi,0′=θ。经过 K 步后,我们得到 θi′=θi,K′。MAML 通过展开这个过程并使用链式法则计算元梯度 ∇θLmeta(θi′),这涉及海森矩阵 ∇2Ltaski。
iMAML 采用不同视角。它假设内循环优化收敛(或近似收敛)到某个满足最优条件的点 θi′。这个条件的常见选择是调整后参数处的任务损失梯度为零(或接近零):
∇θ′Ltaski(θi′)≈0
这个方程隐式地将调整后的参数 θi′ 定义为初始参数 θ 的函数。隐函数定理 (IFT) 提供了一种计算这个隐式定义函数的导数 ∂θ∂θi′ 的方法,无需直接对优化步骤进行微分。
应用隐函数定理
我们来定义一个基于内循环优化的函数 G(θ,θi′)。一种方法是使用内循环目标的最优条件。为简化起见,我们假设内循环从 θ 开始最小化 Ltaski(ϕ),从而得到 θi′。最优条件是 ∇ϕLtaski(ϕ)∣ϕ=θi′=0。我们可以将其看作一个方程 G(θ,θi′)=∇θ′Ltaski(θi′)=0,假设 θi′ 是通过优化过程由 θ 隐式确定的。
或者,在实践中更常见的是,特别是在使用固定数量的梯度步骤时,我们可以根据梯度下降更新本身的定点方程定义 G。如果 θi′ 是从 θ 开始的 K 步 SGD(学习率为 α)的结果,我们可以考虑单步的定点方程(或相关条件)。为了理解核心机制,我们仍使用最优条件 ∇θ′Ltaski(θi′)=0。
元目标是最小化查询集上的损失 Lmeta(θi′),并在任务上平均。元梯度包含项 ∇θLmeta(θi′)。使用链式法则:
∇θLmeta(θi′)=(∂θ∂θi′)T∇θ′Lmeta(θi′)
挑战在于计算雅可比矩阵 ∂θ∂θi′。在 G(θ,θi′)=∇θ′Ltaski(θi′)=0 上使用 IFT,我们有:
∂θ∂G+∂θi′∂G∂θ∂θi′=0
重新排列得到:
∂θ∂θi′=−(∂θi′∂G)−1∂θ∂G
代入 G(θ,θi′)=∇θ′Ltaski(θi′),我们得到:
∂θi′∂G=∇θ′2Ltaski(θi′)(海森矩阵!)
∂θ∂G=0(如果 θi′ 仅通过初始化依赖于 θ。需要仔细处理。)
这种使用精确最优条件的特定表述对于典型的基于梯度下降的内循环并不完全正确,因为 θi′ 确实依赖于 θ。更实际的表述考虑了更新规则的定点,或直接将 IFT 应用于更新序列。
我们来考虑一个在实践中更常用的直接应用。我们想计算向量-雅可比积 vT∂θ∂θi′,其中 v=∇θ′Lmeta(θi′)。iMAML 在不显式形成雅可比或海森矩阵的情况下找到这个积。它使用的事实是,这个积通常可以通过求解涉及海森矩阵 ∇θ′2Ltaski(θi′) 的线性系统来找到。设 H=∇θ′2Ltaski(θi′)。所需项可以通过求解 Hz=v 形式的方程来近似或计算 z,并且元梯度与 z 有关。
主要观察点是,我们不需要完整的海森矩阵 H。我们只需要计算海森-向量积 (Hv),这可以使用有限差分或自动微分有效完成(类似于计算 Pearlmutter 的 R{.} 运算符),而无需实例化完整的海森矩阵。这个海森-向量积正是共轭梯度 (CG) 算法等迭代方法求解线性系统 Hz=v 所需的。
iMAML 算法概述
- 对于每个元任务批次:
- 对于每个任务 i:
- 初始化任务参数:θi,0′=θ。
- 内循环: 执行 K 步梯度下降以计算支持集损失 Ltaski,以获得调整后的参数 θi′=θi,K′。
- 计算查询集梯度:vi=∇θ′Lmeta(θi′)。
- 隐式梯度计算: 使用共轭梯度(近似地)求解线性系统,以找到与 vi 相关的隐式元梯度贡献。这涉及计算海森-向量积,其中 Hi=∇θ′2Ltaski(θi′),但不是 Hi 本身。所求解的精确系统取决于具体的 iMAML 变体和推导(例如,与 Hiz=vi 或 (I+αHi)z=vi 有关)。设结果为 gimplicit,i。
- 存储 gimplicit,i。
- 元更新: 汇总隐式梯度并更新元参数 θ:
θ←θ−βN1i=1∑Ngimplicit,i
(其中 β 是元学习率)。
MAML(通过展开的优化步骤进行显式反向传播)与 iMAML(通过求解与内循环最优值相关的线性系统进行隐式微分)的梯度计算路径比较。
优点与权衡
优点:
- 内存效率: 这是主要优点。iMAML 避免存储内循环优化的计算图,使其内存占用基本不依赖于内循环步骤的数量 K。这对于调整基础模型非常有益,因为即使是单一步骤的计算图也可能很大。
- 计算成本: 虽然使用 CG 求解线性系统会增加计算量,但它可能比计算完整的二阶 MAML 梯度明显更快,特别是对于大的 K。它避免显式形成或存储海森矩阵。
- 潜在的稳定性: 通过关注定点或最优值,iMAML 可能避免与通过内循环中潜在不稳定的优化动态进行微分相关的状况,尤其是在步骤很多的情况下。
缺点:
- 近似质量: iMAML 的准确性取决于定点假设的有效性和线性系统求解器的精度(例如,CG 迭代次数)。如果内循环收敛不佳或 CG 提前终止,则得到的梯度可能不准确。
- 求解器复杂性: 实现和调整迭代求解器(如 CG)与标准自动微分相比增加了复杂性。确保 CG 的收敛有时可能需要仔细的预处理或参数调整。
- 海森-向量积成本: 尽管比计算完整的海森矩阵便宜,但计算海森-向量积仍需谨慎并带来计算成本(大致相当于两次反向传播)。
在基础模型中的情境
iMAML 提供的显著内存节省使其成为使用基础模型进行元学习的有吸引力的选择。标准 MAML 常常由于需要通过具有数十亿参数模型的内部更新进行反向传播所需的内存而变得不可行。尽管像 FOMAML 这样的零阶方法也节省内存,但 iMAML 尝试隐式保留一些二阶信息,可能带来更好的适应性能。然而,海森-向量积的计算成本和 CG 求解器的复杂性在扩展到最大模型时仍是实际考虑因素。将 iMAML 与其他技术(如混合精度训练或模型并行化)结合,对于在超大规模环境中的实际运用可能是必要的。