Optimizers play a fundamental role in navigating the complex loss landscapes associated with training deep neural networks. While standard Stochastic Gradient Descent (SGD) laid the groundwork, its fixed learning rate often proves inadequate for the scale and complexity of Transformer models. Adaptive learning rate algorithms, which dynamically adjust the step size for each parameter, have become essential tools. Among these, Adam and its refinement, AdamW, are the most widely adopted optimizers for training Transformers.
Adam, short for Adaptive Moment Estimation, integrates ideas from both momentum methods and RMSprop. It computes individual adaptive learning rates for different parameters based on estimates of the first and second moments of the gradients.
First Moment (Momentum): Adam maintains an exponentially decaying average of past gradients. This acts like momentum, helping the optimizer accelerate through shallow ravines in the loss surface and dampen oscillations in directions of high curvature. The first moment estimate $m_t$ at timestep $t$ is calculated from the gradient $g_t$:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$

The hyperparameter $\beta_1$ controls the exponential decay rate and is typically set to a value like 0.9.
Second Moment (Variance Adaptation): Adam also maintains an exponentially decaying average of past squared gradients. This serves as an estimate of the variance (or more accurately, the uncentered second moment) of the gradients for each parameter. The estimate is used to scale the learning rate inversely: parameters with larger or noisier gradient histories receive smaller updates, while those with smaller or sparser gradients receive larger updates. The second moment estimate $v_t$ is calculated as:

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

The hyperparameter $\beta_2$ controls this decay rate and is often set to 0.999.
Bias Correction: Because the moment estimates $m_t$ and $v_t$ are initialized at zero, they are biased toward zero during the early stages of training. Adam incorporates bias correction terms to counteract this:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

For example, at $t = 1$ the raw estimate is $m_1 = (1 - \beta_1) g_1$, and dividing by $1 - \beta_1^1$ recovers $\hat{m}_1 = g_1$. These corrected estimates, $\hat{m}_t$ and $\hat{v}_t$, provide better estimates of the true moments, especially early in training.
Parameter Update: Finally, the parameters $w$ are updated using the bias-corrected moment estimates. The update rule scales the base learning rate $\alpha$ by the ratio of the corrected first moment to the square root of the corrected second moment:

$$w_{t+1} = w_t - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

The small constant $\epsilon$ (e.g., $10^{-8}$) is added to the denominator for numerical stability, preventing division by zero.
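These four steps fit naturally into a few lines of code. The sketch below is a minimal NumPy implementation of a single Adam step for one parameter array; the function and variable names (`adam_step`, `state`, and so on) are illustrative rather than a reference implementation.

```python
import numpy as np

def adam_step(w, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to parameter array `w` given its gradient `grad`.

    `state` carries the running moment estimates `m`, `v` and the step counter `t`.
    """
    state["t"] += 1
    t = state["t"]

    # Exponentially decaying averages of the gradient and the squared gradient.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2

    # Bias correction counteracts the zero initialization of m and v.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    # Scale the step by the inverse root of the second-moment estimate.
    return w - lr * m_hat / (np.sqrt(v_hat) + eps)


# Example usage with hypothetical shapes; `grad` would come from backpropagation.
w = np.random.randn(10)
grad = np.random.randn(10)
state = {"m": np.zeros_like(w), "v": np.zeros_like(w), "t": 0}
w = adam_step(w, grad, state)
```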
Adam often demonstrates rapid initial convergence and robust performance across various deep learning tasks. However, a subtle issue arises when combining Adam with standard L2 regularization (weight decay). L2 regularization aims to prevent overfitting by adding a penalty term $\frac{\lambda}{2}\|w\|^2$ to the loss function, which contributes an additional term $\lambda w_t$ to the gradient $g_t$. In the conventional Adam implementation, this combined gradient $g_t = \nabla L(w_t) + \lambda w_t$ is used to compute both $m_t$ and $v_t$. This couples the weight decay effect with the adaptive learning rate mechanism. Specifically, the effective weight decay applied during the update, $\frac{\alpha \lambda w_t}{\sqrt{\hat{v}_t} + \epsilon}$, gets scaled by the adaptive denominator $\sqrt{\hat{v}_t} + \epsilon$. This means weights associated with large historical gradient magnitudes (large $\hat{v}_t$) experience diminished effective decay, potentially hindering the regularization effect.
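To make the coupling concrete, the hypothetical sketch below folds the L2 term into the gradient the way a conventional Adam-plus-L2 implementation does. Note how `lam * w` flows through both moment estimates and ends up divided by the adaptive denominator; the names mirror the illustrative `adam_step` above.

```python
import numpy as np

def adam_l2_step(w, grad_loss, state, lr=1e-3, beta1=0.9, beta2=0.999,
                 eps=1e-8, lam=0.01):
    """Conventional Adam with L2 regularization folded into the gradient."""
    state["t"] += 1
    t = state["t"]

    # The decay term lam * w is mixed into the gradient before the moment
    # updates, so it is also accumulated into m and v.
    g = grad_loss + lam * w
    state["m"] = beta1 * state["m"] + (1 - beta1) * g
    state["v"] = beta2 * state["v"] + (1 - beta2) * g ** 2

    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    # The lam * w contribution inside m_hat is divided by sqrt(v_hat) + eps,
    # so weights with large gradient histories receive weaker effective decay.
    return w - lr * m_hat / (np.sqrt(v_hat) + eps)
```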
AdamW was introduced to resolve the suboptimal interaction between weight decay and the adaptive learning rates in the original Adam algorithm. The core idea proposed by Loshchilov & Hutter (2019) is simple yet effective: decouple the weight decay update from the gradient-based update performed by Adam.
Instead of incorporating the L2 penalty into the gradient $g_t$ used for moment estimation, AdamW calculates the moment estimates and the main adaptive update step using only the gradient derived from the primary loss function, $\nabla L(w_t)$. The weight decay is then applied separately and directly to the weights after the adaptive step.
The process can be summarized as:

1. Compute the gradient of the loss only: $g_t = \nabla L(w_t)$, with no weight decay term added.
2. Update the moment estimates $m_t$ and $v_t$ and apply bias correction exactly as in standard Adam.
3. Compute the adaptive step $\alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$.
4. Apply weight decay directly to the weights: $w_{t+1} = w_t - \eta_t \left( \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda w_t \right)$, where $\eta_t$ is the (optional) learning rate schedule multiplier and $\lambda$ is the weight decay coefficient.
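A minimal sketch of this decoupled update, mirroring the hypothetical Adam functions above: the moments see only the loss gradient, and the decay term is applied directly to the weights, scaled by the schedule multiplier `eta_t`.

```python
import numpy as np

def adamw_step(w, grad_loss, state, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, lam=0.01, eta_t=1.0):
    """One AdamW update: weight decay is decoupled from the moment estimates."""
    state["t"] += 1
    t = state["t"]

    # Moments are computed from the loss gradient only (no lam * w here).
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad_loss
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad_loss ** 2

    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    # Adaptive step and decoupled decay, both scaled by the schedule eta_t.
    return w - eta_t * (lr * m_hat / (np.sqrt(v_hat) + eps) + lam * w)
```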
By decoupling, AdamW ensures that weight decay applies more consistently across all parameters, proportional only to the weight magnitude $w_t$ and the learning rate $\eta_t$, irrespective of the gradient history captured in $\hat{v}_t$. This modification often leads to improved model generalization and better final performance compared to standard Adam with L2 regularization, particularly for complex models like Transformers where effective regularization is highly beneficial. Consequently, AdamW has become the de facto standard optimizer for training modern Transformers.
Effectively using Adam or AdamW requires setting several hyperparameters:

- Learning rate $\alpha$: the base step size. For Transformers it is typically paired with a warmup phase followed by a decay schedule.
- $\beta_1$ and $\beta_2$: the decay rates for the first and second moment estimates, commonly 0.9 and 0.999 respectively.
- $\epsilon$: the numerical stability constant, usually left at a default of around $10^{-8}$.
- Weight decay $\lambda$ (AdamW): the decoupled decay coefficient, with values such as 0.01 or 0.1 commonly used for Transformer models.
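In a framework such as PyTorch, these hyperparameters map directly onto the optimizer constructor. The values below are illustrative defaults rather than recommendations for any particular model.

```python
import torch

model = torch.nn.Linear(512, 512)  # stand-in for a Transformer model

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,               # base learning rate, usually paired with warmup and decay
    betas=(0.9, 0.999),    # decay rates for the first and second moment estimates
    eps=1e-8,              # numerical stability constant
    weight_decay=0.01,     # decoupled weight decay coefficient
)

# Typical training-loop step:
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```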
In summary, while Adam represented a major step forward in optimization, AdamW provides a more refined approach to incorporating weight decay, leading to demonstrable improvements when training large language models. It is generally the recommended choice for Transformer architectures, typically used in conjunction with carefully tuned learning rate schedules and regularization parameters.