模型无关元学习 (MAML) 是一种梯度元学习方法中十分重要的算法。其主要观点是找到一组初始模型参数 θ,使其非常适合快速调整。MAML 并非学习在所有任务上平均表现良好的参数,而是学习那些在新任务上只需少量数据和梯度更新即可在该任务上获得良好表现的参数。
数学表述
我们将其形式化。设任务分布为 p(T)。在元训练期间,我们抽样一个批次的任务 {Ti}i=1B。每个任务 Ti 关联着一个损失函数 LTi,通常包含一个用于适应的支持集 DTisupp 和一个用于评估调整后参数的查询集 DTiquery。
核心思想涉及一个两阶段优化过程:
-
内循环(任务特定调整): 对于每个任务 Ti,从共享的初始参数 θ 开始,我们使用任务的支持集 DTisupp 进行一次或多次梯度下降步骤。对于学习率为 α 的单次梯度步骤:
θi′=θ−α∇θLTi(θ,DTisupp)
这些调整后的参数 θi′ 特定于任务 Ti。
-
外循环(元优化): 目标是更新初始参数 θ,以最小化调整后跨任务的预期损失。元目标函数是使用调整后的参数 θi′ 在其各自查询集 DTiquery 上计算的损失之和(或平均值):
θminTi∼p(T)∑LTi(θi′,DTiquery)
代入 θi′ 的表达式(来自内循环),目标变为:
θminTi∼p(T)∑LTi(θ−α∇θLTi(θ,DTisupp),DTiquery)
元参数 θ 基于此元目标使用梯度下降进行更新,通常使用元学习率 β:
θ←θ−β∇θTi∼p(T)∑LTi(θi′,DTiquery)
元梯度计算
计算元梯度 ∇θ∑TiLTi(θi′,DTiquery) 是 MAML 最复杂的部分。由于 θi′ 通过梯度更新步骤依赖于 θ,应用链式法则需要微分内循环的梯度。
我们考虑一个任务 T 并简化表示:Lsupp(θ)=LT(θ,DTsupp) 和 Lquery(θ′)=LT(θ′,DTquery)。调整后的参数为 θ′=θ−α∇θLsupp(θ)。此任务的元梯度为:
∇θLquery(θ′)=∇θLquery(θ−α∇θLsupp(θ))
应用链式法则得出:
∇θLquery(θ′)=∇θ′Lquery(θ′)⋅∇θ(θ−α∇θLsupp(θ))
∇θLquery(θ′)=∇θ′Lquery(θ′)⋅(I−α∇θ2Lsupp(θ))
此处,∇θ′Lquery(θ′) 是查询损失对 调整后 参数 θ′ 的梯度,在 θ′ 处计算。项 ∇θ2Lsupp(θ) 是支持集损失对 初始 参数 θ 的 Hessian 矩阵。
此计算需要计算梯度 ∇θ′Lquery(θ′)、Hessian ∇θ2Lsupp(θ),将 Hessian 乘以 α 和梯度向量,并进行矩阵减法和乘法。这涉及二阶导数,使标准 MAML 在计算上要求较高。
MAML 算法伪代码
以下是 MAML 算法的简化表示:
算法:MAML
要求: 任务分布 p(T)
要求: 步长 α,β
- 随机初始化 θ
- 当 未收敛 时
- \hspace{0.5cm} 抽样一批任务 Ti∼p(T),其中 i=1,…,B
- \hspace{0.5cm} 对于所有 Ti 执行
- \hspace{1cm} 使用支持集评估 ∇θLTi(θ,DTisupp)
- \hspace{1cm} 计算调整后的参数 θi′=θ−α∇θLTi(θ,DTisupp) (内循环更新)
- \hspace{0.5cm} 结束循环
- \hspace{0.5cm} 更新 θ←θ−β∇θ∑i=1BLTi(θi′,DTiquery) (外循环更新,需要使用查询集反向传播通过步骤 6)
- 结束当循环
- 返回 θ
计算考量
MAML 的主要计算难题在于外循环更新(步骤 8),具体来说是涉及二阶导数的元梯度计算。
-
Hessian 计算/Hessian-向量积: 对于具有数百万或数十亿参数的深度学习模型,显式构建 Hessian 矩阵 ∇θ2Lsupp(θ) 在计算上是不可行的,因为其大小为 d×d,其中 d 是参数数量。值得注意的是,元梯度计算仅需要 Hessian 与向量 (∇θ′Lquery(θ′)) 的 乘积。这种 Hessian-向量积 (HVP) 通常可以高效计算,无需构建完整的 Hessian,一般使用有限差分或自动微分技术(例如,Pearlmutter 的技巧,涉及第二次反向传播)。然而,即使是计算 HVP 也会比标准一阶梯度计算增加可观的计算开销。
-
内存占用: 使用自动微分框架的标准实现需要存储内循环更新的计算图,以执行外循环梯度的反向传播。此图包含中间激活和梯度,显著增加内存需求,特别是在使用多个内循环步骤或处理大型基础模型时。
-
计算图: 整体计算涉及内循环(每个任务)的前向传播和反向传播,随后是使用调整后的参数在查询集上的前向传播,最后是元梯度计算的反向传播,其本身涉及与内循环梯度相关的计算。这种嵌套结构增加了整体计算成本。
元参数 (θ)、任务调整参数 (θ'_i)、支持/查询损失以及 MAML 中的梯度流之间的关系。外循环根据调整后在查询集上的表现优化 θ,需要通过内循环的梯度更新步骤(红色箭头)进行反向传播,这涉及二阶导数。
这些计算要求促进了像一阶 MAML (FOMAML) 这样的近似方法和像隐式 MAML (iMAML) 这样的其他方法的出现,我们接下来将考量这些方法。对 MAML 精确机制和成本的理解,为评估这些更具扩展性的变体提供了必要的支撑。