The substantial memory footprint associated with meta-learning on foundation models stems primarily from two sources: storing intermediate activations for gradient calculation during backpropagation, and maintaining optimizer states for potentially billions of parameters. When dealing with meta-gradients, particularly the second-order derivatives required by methods like MAML, these demands multiply, quickly overwhelming available hardware memory. Fortunately, several optimization techniques can significantly alleviate these memory pressures, making large-scale meta-learning practical.
Standard backpropagation requires storing all intermediate activations computed during the forward pass to calculate gradients during the backward pass. For deep networks like foundation models, and especially within the unrolled computation graph of meta-learning (spanning multiple inner loop steps), the memory needed for these activations becomes a major bottleneck.
Gradient checkpointing, also known as activation recomputation, offers a direct trade-off: reduce memory consumption at the cost of increased computation time. Instead of storing all activations, checkpointing strategically selects certain activations to save (the "checkpoints") while discarding others. During the backward pass, when a discarded activation is needed for a gradient calculation, the portion of the model between the previous checkpoint and the required activation is recomputed on-the-fly.
How it Works: Imagine the forward pass as a sequence of functions $f_1, f_2, \ldots, f_n$. Without checkpointing, all intermediate outputs $x_1 = f_1(x_0), x_2 = f_2(x_1), \ldots, x_n = f_n(x_{n-1})$ are stored. With checkpointing, we might only store $x_0$ and $x_{n/2}$. To compute the gradient at $f_{n/2+1}$, we re-run the forward pass from $x_{n/2}$ to compute $x_{n/2+1}$, use it for the gradient, and then discard it again.
Comparison of standard backpropagation memory usage versus gradient checkpointing. Checkpointing avoids storing intermediate activations (like c1, c3) by recomputing them during the backward pass (red dashed lines).
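Below is a minimal PyTorch sketch of this idea, using a toy stack of linear layers standing in for $f_1, \ldots, f_n$ (the layer sizes and segment count are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack of layers standing in for f_1, ..., f_n.
layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)])
x0 = torch.randn(8, 1024, requires_grad=True)

# A plain forward pass, layers(x0), would keep all 16 intermediate activations.
# With 4 segments, only the activations at segment boundaries are stored; the
# ones inside each segment are recomputed on demand during the backward pass.
y = checkpoint_sequential(layers, 4, x0, use_reentrant=False)
y.sum().backward()
```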
Relevance to Meta-Learning: Gradient checkpointing can be applied within the inner loop updates (task adaptation) and/or across the outer loop meta-update. It is particularly effective for MAML and its variants, where the computational graph includes the inner optimization process. By checkpointing segments of the inner loop or layers within the foundation model, the peak memory usage during the calculation of potentially complex meta-gradients (like $\nabla_\theta \mathcal{L}_{\text{meta}}$) can be substantially reduced. The trade-off is a roughly 20-30% increase in training time, depending on the checkpointing strategy and model architecture, but it often enables training configurations that would otherwise be impossible due to memory limits.
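For foundation models loaded through Hugging Face transformers, activation recomputation can usually be switched on with a single call before the inner-loop unroll begins; a sketch, using gpt2 purely as a placeholder checkpoint:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder foundation model

# Recompute each transformer block's activations during backward instead of
# storing them, including inside unrolled inner-loop adaptation steps.
model.gradient_checkpointing_enable()

# The generation KV cache is incompatible with checkpointing, so turn it off
# for training-style forward passes.
model.config.use_cache = False
```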
Another powerful technique is mixed-precision training. Instead of performing all computations and storing all values (weights, activations, gradients) in standard 32-bit floating-point precision (FP32), mixed-precision training utilizes lower-precision formats like 16-bit floating-point (FP16) or Brain Floating Point (BF16) for many operations.
Benefits: Lower-precision formats halve the memory needed for activations and gradients (2 bytes per value instead of 4), and modern accelerators execute FP16/BF16 matrix multiplications on dedicated tensor cores, which typically increases training throughput as well.
Challenges and Solutions: FP16 has a limited dynamic range compared to FP32, making it susceptible to numerical underflow (gradients becoming zero) or overflow (gradients becoming infinity) during training, especially with large models. BF16 offers a wider dynamic range (similar to FP32) but less precision than FP16, often providing a better balance for training large transformers.
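The difference in dynamic range and precision can be inspected directly; for example:

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # FP16 overflows beyond 65504; BF16 keeps roughly FP32's 3.4e38 range but
    # has a much larger eps, i.e. fewer bits of precision per value.
    print(dtype, "max:", info.max, "smallest normal:", info.tiny, "eps:", info.eps)
```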
The primary technique to counteract underflow/overflow in FP16 is loss scaling. The loss value is multiplied by a scaling factor before backpropagation begins. This scales up the gradients, shifting them into the representable range of FP16. Before the optimizer updates the weights, the gradients are scaled back down by the same factor. This scaling factor can be fixed (static loss scaling) or adjusted dynamically during training (dynamic loss scaling) to find an optimal value that avoids overflow while minimizing underflow.
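A minimal sketch of dynamic loss scaling with PyTorch's automatic mixed precision, assuming a CUDA device and using a toy model and placeholder loss:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # dynamic loss scaling: the scale factor adapts during training

for step in range(10):
    batch = torch.randn(8, 1024, device="cuda")  # stand-in for real task data
    optimizer.zero_grad(set_to_none=True)
    with autocast(dtype=torch.float16):          # run the forward pass in FP16
        loss = model(batch).pow(2).mean()        # placeholder loss
    scaler.scale(loss).backward()                # backpropagate the scaled-up loss
    scaler.step(optimizer)                       # unscale gradients; skip step on inf/NaN
    scaler.update()                              # grow or shrink the scale factor
```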
Relevance to Meta-Learning: Mixed-precision can be applied throughout the meta-learning process. Both the inner-loop adaptation steps and the outer-loop meta-update can leverage lower-precision computations and storage.
Using BF16 is often preferred for large language models due to its better stability during training, although FP16 with effective loss scaling can also work and might offer slightly faster computation on some hardware.
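With BF16, the same autocast mechanism is typically used without a GradScaler, since its dynamic range matches FP32; a sketch assuming BF16-capable hardware:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
batch = torch.randn(8, 1024, device="cuda")

# BF16 shares FP32's exponent range, so loss scaling is usually unnecessary.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(batch).pow(2).mean()  # placeholder loss
loss.backward()  # parameters and gradients remain FP32; only the matmuls run in BF16
```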
Standard optimizers like Adam or AdamW maintain additional state alongside the model parameters. Adam, for instance, stores estimates of the first moment (momentum) and the second moment (variance) for each parameter. For a model with N parameters stored in FP32, the optimizer states typically require an additional 2×N×4 bytes, roughly tripling the memory footprint compared to storing the parameters alone. When N is in the billions, this becomes a significant memory burden, particularly for the meta-optimizer, which operates on the entire foundation model's parameters θ.
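As a back-of-the-envelope illustration for a hypothetical 7-billion-parameter model:

```python
n_params = 7e9                            # hypothetical 7B-parameter model
weights_gb = n_params * 4 / 1e9           # FP32 weights:            ~28 GB
adam_states_gb = n_params * 2 * 4 / 1e9   # first + second moments:  ~56 GB
print(weights_gb + adam_states_gb)        # ~84 GB before gradients and activations
```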
Memory-efficient optimizers reduce this burden:
Adafactor: Proposed initially for large NLP models, Adafactor avoids storing the full second moment estimate for each parameter. Instead, it maintains only row and column sums of the squared gradients for weight matrices, effectively storing a factored representation of the second moment tensor. For an n×m layer, this reduces the second-moment state from O(nm) values to O(n+m), offering substantial savings without significantly compromising performance. By default it also omits the first moment entirely, saving further memory (see the sketch after this list).
8-bit Optimizers: Libraries like bitsandbytes implement optimizers (e.g., 8-bit Adam) that quantize the optimizer states (momentum and variance) to 8-bit integers. Instead of storing 32-bit floats for each state entry, they store an 8-bit integer plus quantization statistics (like block-wise scaling factors) needed to dequantize the state just before the parameter update is performed. This can reduce the optimizer state memory by roughly 75% (e.g., from 8 bytes per parameter for FP32 states down to approximately 2 bytes per parameter).
Other Approaches (e.g., Sophia, Lion): Research continues to produce optimizers aiming for Adam-like performance with reduced memory or computational overhead, though their widespread adoption and robustness for meta-learning large models are still evolving.
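As referenced above, here is a sketch of Adafactor using the implementation shipped with Hugging Face transformers (the toy model and hyperparameters are illustrative; with an explicit learning rate, relative_step and scale_parameter are typically disabled):

```python
import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(4096, 4096)  # stand-in for a large weight matrix

# Second moments are stored as factored row/column statistics per weight
# matrix; with beta1=None (the default) no first-moment buffer is kept.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)
```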
Relevance to Meta-Learning: While the inner loop optimizer might operate on a smaller set of parameters (e.g., if using PEFT methods like LoRA during adaptation), the meta-optimizer updating the base foundation model parameters θ operates on the full parameter set. Applying Adafactor or 8-bit Adam to this outer loop optimization step drastically cuts down the memory required for optimizer states, freeing up critical resources.
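A sketch of using bitsandbytes' 8-bit AdamW as the outer-loop meta-optimizer over the full parameter set (the model here is a toy stand-in for the foundation model, and the learning rate is illustrative):

```python
import torch
import bitsandbytes as bnb

# Toy stand-in for the full foundation model parameters theta.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.Linear(4096, 4096),
).cuda()

# Momentum and variance are kept as block-wise quantized 8-bit tensors,
# cutting optimizer state memory from ~8 bytes to ~2 bytes per parameter.
meta_optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-5)
```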
Relative memory contributions per parameter for different optimization strategies. The "Activations" bars reflect peak usage: in the "Efficient Opt"-only case activations are still stored in FP32, while the "Combined" case adds gradient checkpointing, which reduces peak activation memory significantly. Mixed precision halves activation and gradient memory, and 8-bit optimizers drastically reduce optimizer state memory. Combining the techniques yields the largest savings.
By strategically combining gradient checkpointing, mixed-precision training, and memory-efficient optimizers, it becomes feasible to manage the memory demands of meta-learning even on massive foundation models, enabling the application of these powerful adaptation techniques at scale.
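The sketch below combines the three techniques in one training loop; the Sequential stack and squared-output loss are toy stand-ins for a foundation model and its meta-objective, and a CUDA device with bitsandbytes installed is assumed:

```python
import torch
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint_sequential

model = torch.nn.Sequential(*[torch.nn.Linear(2048, 2048) for _ in range(12)]).cuda()
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)  # 8-bit optimizer states

for step in range(10):
    batch = torch.randn(4, 2048, device="cuda")  # stand-in for task data
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):         # mixed precision
        out = checkpoint_sequential(model, 4, batch, use_reentrant=False)  # checkpointing
        loss = out.pow(2).mean()                 # placeholder (meta-)objective
    loss.backward()
    optimizer.step()
```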