Model-Agnostic Meta-Learning (MAML) stands as a prominent gradient-based meta-learning algorithm. Its central premise, as introduced in the chapter overview, is to find a set of initial model parameters θ that are highly amenable to rapid adaptation. Instead of learning parameters that perform well on average across tasks, MAML learns parameters that require only a few gradient updates on a small amount of data from a new task to achieve good performance on that specific task.
Let's formalize this. Consider a distribution of tasks $p(T)$. During meta-training, we sample a batch of tasks $\{T_i\}_{i=1}^{B}$. Each task $T_i$ is associated with a loss function $L_{T_i}$ and typically provides a support set $D_{T_i}^{supp}$ for adaptation and a query set $D_{T_i}^{query}$ for evaluating the adapted parameters.
The core idea involves a two-step optimization process:
Inner Loop (Task-Specific Adaptation): For each task $T_i$, starting from the shared initial parameters $\theta$, we perform one or more gradient descent steps using the task's support set $D_{T_i}^{supp}$. For a single gradient step with learning rate $\alpha$:

$$\theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(\theta, D_{T_i}^{supp})$$

These adapted parameters $\theta_i'$ are specific to task $T_i$.
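As a concrete sketch, the inner update can be written functionally so that the adapted parameters keep their dependence on $\theta$. The linear model, parameter layout, and learning rate below are illustrative assumptions, not part of MAML itself; in PyTorch, passing create_graph=True is what later allows the outer loop to differentiate through this step.

```python
import torch

def forward(params, x):
    # Hypothetical single-layer linear model; params is a list [W, b].
    W, b = params
    return x @ W + b

def task_loss(params, x, y):
    # Mean squared error on one task's data.
    return ((forward(params, x) - y) ** 2).mean()

def inner_update(params, x_supp, y_supp, alpha=0.01):
    # One inner-loop gradient step on the support set.
    # create_graph=True retains the graph of this update so the outer loop
    # can backpropagate through it (needed for second-order MAML).
    loss = task_loss(params, x_supp, y_supp)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return [p - alpha * g for p, g in zip(params, grads)]
```

Here params would be a list of tensors created with requires_grad=True; the returned adapted parameters are functions of $\theta$ rather than detached copies.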
Outer Loop (Meta-Optimization): The goal is to update the initial parameters $\theta$ to minimize the expected loss across tasks after adaptation. The meta-objective is the sum (or average) of the losses computed with the adapted parameters $\theta_i'$ on their respective query sets $D_{T_i}^{query}$:

$$\min_\theta \sum_{T_i \sim p(T)} L_{T_i}(\theta_i', D_{T_i}^{query})$$

Substituting the expression for $\theta_i'$ from the inner loop, the objective becomes:

$$\min_\theta \sum_{T_i \sim p(T)} L_{T_i}\big(\theta - \alpha \nabla_\theta L_{T_i}(\theta, D_{T_i}^{supp}),\, D_{T_i}^{query}\big)$$

The meta-parameters $\theta$ are updated by gradient descent on this meta-objective, typically with a meta-learning rate $\beta$:

$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T_i \sim p(T)} L_{T_i}(\theta_i', D_{T_i}^{query})$$
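Continuing the sketch above, one meta-update over a batch of tasks can be expressed by summing the query-set losses of the adapted parameters and differentiating with respect to the initial parameters. The task_batch structure and step sizes here are assumptions made for illustration.

```python
import torch

def meta_step(params, task_batch, alpha=0.01, beta=0.001):
    # task_batch is assumed to be a list of (x_supp, y_supp, x_query, y_query)
    # tuples; params is a list of tensors with requires_grad=True.
    meta_loss = 0.0
    for x_supp, y_supp, x_query, y_query in task_batch:
        adapted = inner_update(params, x_supp, y_supp, alpha)   # inner loop
        meta_loss = meta_loss + task_loss(adapted, x_query, y_query)

    # Differentiating the summed query loss w.r.t. the *initial* parameters
    # backpropagates through the inner update, which is where the
    # second-order terms enter.
    meta_grads = torch.autograd.grad(meta_loss, params)
    with torch.no_grad():
        for p, g in zip(params, meta_grads):
            p -= beta * g
    return meta_loss.item()
```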
Calculating the meta-gradient $\nabla_\theta \sum_{T_i} L_{T_i}(\theta_i', D_{T_i}^{query})$ is the most intricate part of MAML. Since $\theta_i'$ depends on $\theta$ through a gradient update step, applying the chain rule involves differentiating through the inner loop's gradient.

Let's consider a single task $T$ and simplify notation: $L_{supp}(\theta) = L_T(\theta, D_T^{supp})$ and $L_{query}(\theta') = L_T(\theta', D_T^{query})$. The adapted parameters are $\theta' = \theta - \alpha \nabla_\theta L_{supp}(\theta)$. The meta-gradient for this task is:

$$\nabla_\theta L_{query}(\theta') = \nabla_\theta L_{query}\big(\theta - \alpha \nabla_\theta L_{supp}(\theta)\big)$$

Applying the chain rule yields:

$$\nabla_\theta L_{query}(\theta') = \nabla_{\theta'} L_{query}(\theta') \cdot \nabla_\theta\big(\theta - \alpha \nabla_\theta L_{supp}(\theta)\big)$$

$$\nabla_\theta L_{query}(\theta') = \nabla_{\theta'} L_{query}(\theta') \cdot \big(I - \alpha \nabla_\theta^2 L_{supp}(\theta)\big)$$

Here, $\nabla_{\theta'} L_{query}(\theta')$ is the gradient of the query loss with respect to the adapted parameters, evaluated at $\theta'$. The term $\nabla_\theta^2 L_{supp}(\theta)$ is the Hessian matrix of the support set loss with respect to the initial parameters $\theta$.
This calculation requires computing the gradient $\nabla_{\theta'} L_{query}(\theta')$ and, in effect, the Hessian $\nabla_\theta^2 L_{supp}(\theta)$: the query gradient must be multiplied by the matrix $\big(I - \alpha \nabla_\theta^2 L_{supp}(\theta)\big)$. Because this involves second-order derivatives, standard MAML is computationally demanding.
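The chain-rule expression can be checked numerically on a toy problem. The sketch below uses a small quadratic loss and a single parameter vector (all shapes, data, and the value of $\alpha$ are arbitrary assumptions) and compares differentiating through the inner step against explicitly forming $\big(I - \alpha \nabla_\theta^2 L_{supp}(\theta)\big)$ times the query gradient.

```python
import torch

torch.manual_seed(0)
d, alpha = 3, 0.1
theta = torch.randn(d, requires_grad=True)
A_supp, b_supp = torch.randn(d, d), torch.randn(d)     # toy support data
A_query, b_query = torch.randn(d, d), torch.randn(d)   # toy query data

def loss_supp(p):
    return ((A_supp @ p - b_supp) ** 2).mean()

def loss_query(p):
    return ((A_query @ p - b_query) ** 2).mean()

# (1) Meta-gradient by differentiating through the inner update.
g_supp = torch.autograd.grad(loss_supp(theta), theta, create_graph=True)[0]
theta_adapted = theta - alpha * g_supp
meta_grad_autograd = torch.autograd.grad(loss_query(theta_adapted), theta)[0]

# (2) Meta-gradient from the explicit chain rule:
#     (I - alpha * Hessian of the support loss) @ query gradient at theta'.
theta_prime = theta_adapted.detach().requires_grad_(True)
grad_query = torch.autograd.grad(loss_query(theta_prime), theta_prime)[0]
H_supp = torch.autograd.functional.hessian(loss_supp, theta.detach())
meta_grad_explicit = (torch.eye(d) - alpha * H_supp) @ grad_query

print(torch.allclose(meta_grad_autograd, meta_grad_explicit, atol=1e-5))  # should print True
```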
Here is a simplified representation of the MAML algorithm:
Algorithm: MAML
Require: Distribution over tasks $p(T)$
Require: Step sizes $\alpha, \beta$
1: Randomly initialize $\theta$
2: while not done do
3:   Sample a batch of tasks $T_i \sim p(T)$
4:   for each sampled task $T_i$ do
5:     Evaluate $\nabla_\theta L_{T_i}(\theta, D_{T_i}^{supp})$ on the support set
6:     Compute adapted parameters: $\theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(\theta, D_{T_i}^{supp})$
7:   end for
8:   Update $\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T_i} L_{T_i}(\theta_i', D_{T_i}^{query})$ using the query sets
9: end while
The primary computational challenge in MAML lies in the outer loop update (step 8), specifically the calculation of the meta-gradient involving second-order derivatives.
Hessian Computation/Hessian-Vector Products: Explicitly forming the Hessian matrix $\nabla_\theta^2 L_{supp}(\theta)$ is computationally infeasible for deep learning models with millions or billions of parameters, as its size is $d \times d$, where $d$ is the number of parameters. Fortunately, the meta-gradient calculation only requires the product of the Hessian and a vector ($\nabla_{\theta'} L_{query}(\theta')$). This Hessian-vector product (HVP) can often be computed efficiently without forming the full Hessian, typically using finite differences or automatic differentiation techniques (e.g., Pearlmutter's trick, which involves a second backward pass). However, even computing HVPs adds significant computational overhead compared to standard first-order gradient calculations.
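As a sketch of the autodiff approach, an HVP can be obtained with a double backward pass: compute the gradient with the graph retained, take its dot product with the fixed vector, and differentiate again. The function below assumes params and vector are matching lists of tensors and is illustrative rather than a reference implementation.

```python
import torch

def hessian_vector_product(loss_fn, params, vector):
    # Gradient of the loss with the graph retained, so it can be
    # differentiated a second time.
    grads = torch.autograd.grad(loss_fn(params), params, create_graph=True)
    # Dot product of the gradient with the (constant) vector ...
    grad_dot_v = sum((g * v).sum() for g, v in zip(grads, vector))
    # ... and one more backward pass gives H @ vector without forming H.
    return torch.autograd.grad(grad_dot_v, params)
```

The cost is roughly that of one additional backward pass, which is still noticeably more than a plain first-order gradient.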
Memory Usage: Standard implementations using automatic differentiation frameworks require storing the computation graph of the inner loop update(s) to perform the backward pass for the outer loop gradient. This graph includes intermediate activations and gradients, substantially increasing memory requirements, especially when multiple inner loop steps are used or when dealing with large foundation models.
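To make the memory trade-off concrete, the sketch below (reusing the hypothetical task_loss and params names from the earlier snippets) contrasts retaining the inner-loop graph with detaching the adapted parameters, which is essentially the first-order shortcut discussed later.

```python
import torch

def adapt(params, x_supp, y_supp, alpha=0.01, second_order=True):
    # Reuses the hypothetical task_loss from the earlier sketch.
    loss = task_loss(params, x_supp, y_supp)
    if second_order:
        # Keep the inner-loop graph: intermediate activations and gradients
        # stay in memory until the outer backward pass runs.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        return [p - alpha * g for p, g in zip(params, grads)]
    # First-order variant: the inner graph is freed immediately, so memory
    # stays close to that of ordinary training, at the cost of dropping the
    # second-order terms.
    grads = torch.autograd.grad(loss, params)
    return [(p - alpha * g).detach().requires_grad_(True)
            for p, g in zip(params, grads)]
```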
Computational Graph: The overall computation involves a forward and backward pass for the inner loop (per task), a forward pass with the adapted parameters on the query set, and finally a backward pass for the meta-gradient that must itself backpropagate through the inner loop's gradient computation. This nested structure contributes substantially to the overall computational cost.
Figure: Relationship between the meta-parameters $\theta$, the task-adapted parameters $\theta_i'$, the support/query losses, and the gradient flow in MAML. The outer loop optimizes $\theta$ based on query-set performance after adaptation, which requires backpropagating through the inner loop's gradient update step and therefore involves second-order derivatives.
These computational demands motivate the development of approximations like First-Order MAML (FOMAML) and alternative approaches like Implicit MAML (iMAML), which we will examine next. Understanding the exact mechanism and cost of MAML provides the necessary foundation for appreciating these more scalable variants.