元学习可以构建为双层优化问题。外层循环旨在找到最优的元参数 θ(例如,模型初始化、学习率),以最小化在多个任务上平均的元目标 Lmeta。内层循环通过最小化任务特定的损失 Ltask 来寻找任务参数 ϕ∗,这些参数可能从 θ 开始或受其引导。具体来说:
θminET∼p(T)[Lmeta(ϕ∗(θ,T))]
满足 ϕ∗(θ,T)=argϕminLtask(ϕ;θ,DTtr)
这里,T 表示从分布 p(T) 中采样的一个任务,DTtr 是任务 T 的支持集,且 Lmeta 通常在查询集 DTqry 上评估。主要难点在于计算外层目标相对于元参数 θ 的梯度,这需要通过确定 ϕ∗ 的内层优化过程。存在一些算法策略来解决这种依赖。
通过内层循环展开的梯度下降
最直接的方法,例如 MAML 等算法,涉及将内层循环的优化过程视为外层目标计算图的一部分。如果内层循环使用 K 步梯度下降来寻找近似解 ϕK,从 ϕ0=f(θ) 开始(其中 f 可能是恒等函数或将元参数映射到初始任务参数的某个函数):
ϕk+1=ϕk−α∇ϕLtask(ϕk;θ,DTtr)对于 k=0,…,K−1
我们可以使用链式法则计算外层梯度 ∇θLmeta(ϕK),通过内层优化的所有 K 步进行反向传播。这实质上“展开”了内层循环。
原理:
梯度计算涉及诸如 ∂θ∂ϕK 的项。通过 K 步重复应用链式法则得到:
∇θLmeta(ϕK)=∇ϕKLmeta⋅∂θ∂ϕK
其中 ∂θ∂ϕK 取决于在每一步 k=0,…,K−1 上 Ltask 相对于 ϕ 和 θ 的梯度。如果 ϕ0=θ,则依赖是直接的。如果 θ 影响 Ltask 本身(例如,超参数适应),则会出现额外的项。
难点:
- 计算成本: 经过 K 步优化进行反向传播可能计算密集,特别是当需要 Ltask 的二阶导数时(如在精确 MAML 中)。成本大致与 K 呈线性关系。
- 内存占用: 反向传播需要存储 K 个内层步骤的中间激活和梯度,导致大量的内存消耗,对大型基础模型尤其不利。
- 梯度消失/爆炸: 对于大的 K,通过展开的优化路径传播的梯度可能出现消失或爆炸问题,类似于训练深度循环网络。
像 FOMAML 这样的一阶近似方法通过在反向传播期间忽略二阶导数项来降低成本,大幅减少计算量,但可能影响性能。Reptile 通过在任务上重复进行 SGD 步骤来近似元梯度。
通过内层循环展开计算梯度。元梯度 ∇θLmeta 需要通过产生任务参数 ϕK 的一系列内层优化步骤进行反向传播。
隐式微分方法
另一种方法避免了显式展开内层循环。隐式微分基于内层循环收敛到满足某个最优条件的点 ϕ∗ 的假设,通常是任务损失的梯度为零:
∇ϕLtask(ϕ∗(θ);θ)=0
假设此条件成立,我们可以对其进行相对于 θ 的隐式微分。应用链式法则得到:
dθd[∇ϕLtask(ϕ∗(θ);θ)]=0
∇ϕϕ2Ltask⋅∂θ∂ϕ∗+∇ϕθ2Ltask=0
这里,∇ϕϕ2Ltask 是内层目标相对于 ϕ 的 Hessian 矩阵,而 ∇ϕθ2Ltask 是混合偏导数,两者都在 (ϕ∗(θ),θ) 处评估。我们可以重新排列以找到雅可比矩阵 ∂θ∂ϕ∗:
∂θ∂ϕ∗=−(∇ϕϕ2Ltask)−1∇ϕθ2Ltask
随后可以使用链式法则计算外层梯度 ∇θLmeta(ϕ∗):
∇θLmeta(ϕ∗)=∇ϕ∗Lmeta⋅∂θ∂ϕ∗=−∇ϕ∗Lmeta(∇ϕϕ2Ltask)−1∇ϕθ2Ltask
原理:
重要的是,这种方法避免了显式形成或求逆可能庞大的 Hessian 矩阵 ∇ϕϕ2Ltask。相反,计算涉及求解线性系统或计算 Hessian-向量积 (HVPs)。例如,计算最终梯度首先涉及计算向量 v=∇ϕ∗Lmeta,然后求解线性系统:
(∇ϕϕ2Ltask)z=∇ϕθ2Ltask(求解矩阵 z=∂θ∂ϕ∗ 按列)
或者直接计算所需的乘积:
g=vT(∇ϕϕ2Ltask)−1∇ϕθ2Ltask(计算涉及逆的 HVP)
诸如共轭梯度算法等高效方法可以迭代地求解线性系统或计算逆 Hessian-向量积 (∇ϕϕ2Ltask)−1vT,仅需能够计算任意向量 u 的 Hessian-向量积 ∇ϕϕ2Ltask⋅u。这通常可以通过自动微分高效完成,而无需形成完整的 Hessian 矩阵。诸如隐式 MAML (iMAML) 等算法应用了此技术。
优势:
- 内存效率: 不需要存储内层循环的中间激活,这使其可能更适合大型模型和长内层优化周期。内存成本大致与内层步骤数 K 无关。
- 稳定性: 对于大的 K,它可能比展开更稳定,避免通过展开步骤导致的梯度爆炸/消失。
难点:
- 内层循环收敛性: 基于内层循环近似收敛到驻点的假设。如果内层优化提前停止或收敛不佳,性能可能会下降。
- Hessian 逆计算: 求解线性系统或计算逆 HVP 仍然可能计算密集,尽管通常比带有二阶导数的完全展开更快。Hessian 的条件数影响共轭梯度等迭代求解器的收敛速度。
- 实现复杂性: 需要仔细实现 Hessian-向量积计算和迭代线性求解器。
通过隐式微分计算梯度。这种方法不展开内层循环,而是使用内层循环的最优条件(∇ϕLtask=0)和隐函数定理 (IFT) 来计算 θ 和 ϕ∗ 之间的关系,从而通常通过 Hessian-向量积 (HVPs) 和线性求解器计算 ∇θLmeta。
方法比较
展开与隐式微分之间的选择涉及权衡:
- 展开(例如,MAML,FOMAML):
- 使用标准自动微分框架实现更简单。
- 不需要内层循环收敛,即使在少量步骤后也适用。
- 可能内存密集且计算成本高(特别是二阶)。
- 对于许多内层步骤,易受梯度问题影响。
- 隐式微分(例如,iMAML):
- 内存效率更高,随内层步骤数扩展性更好。
- 对于长适应周期,梯度可能更稳定。
- 需要内层循环近似一个驻点。
- 涉及求解线性系统(HVP 计算),这可能自身带有计算成本和稳定性问题(例如,Hessian 的条件数)。
对于内存是主要限制的大型基础模型,且适应可能涉及许多有效步骤(即使是隐式的),隐式微分方法提供了一种有吸引力的替代方案。然而,实际性能在很大程度上取决于内层优化问题的具体情况以及 HVP 计算和线性求解器的效率。混合方法或进一步的近似方法也是活跃的研究方向,旨在结合两种模式的优势。