While standard Stochastic Gradient Descent (SGD) and its momentum variant form the bedrock of optimization, training large, complex models like LLMs often benefits from adaptive learning rate methods. These algorithms adjust the learning rate for each parameter individually, potentially leading to faster convergence, especially in settings with sparse gradients or varying gradient magnitudes across parameters, which are common in deep neural networks.
One of the most popular adaptive optimizers is Adam (Adaptive Moment Estimation). Adam computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients. It essentially combines the ideas of Momentum (using the first moment estimate, an exponentially decaying average of past gradients) and RMSprop (using the second moment estimate, an exponentially decaying average of past squared gradients).
Let $g_t$ be the gradient of the objective function with respect to the parameters $\theta$ at timestep $t$. Adam maintains two moving averages:
First Moment Estimate (Momentum):
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
This is an estimate of the mean of the gradients. $\beta_1$ is the exponential decay rate, typically close to 1 (e.g., 0.9).
Second Moment Estimate (Uncentered Variance):
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
This is an estimate of the uncentered variance of the gradients (element-wise square). $\beta_2$ is the exponential decay rate, also typically close to 1 (e.g., 0.999).
Since $m_t$ and $v_t$ are initialized as vectors of zeros, they are biased towards zero, especially during the initial timesteps. Adam corrects for this bias:
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t} \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
where $t$ is the current timestep index (starting from 1).
Finally, the parameter update rule is:
$$\theta_t = \theta_{t-1} - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Here, $\eta$ is the base learning rate, and $\epsilon$ is a small constant (e.g., $10^{-8}$) added for numerical stability, primarily to prevent division by zero. The term $\frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}$ acts as an effective, parameter-specific learning rate. Parameters with larger past gradients (larger $\hat{v}_t$) receive smaller updates, while parameters with smaller past gradients receive larger updates.
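To connect these formulas to code, the following is a minimal, illustrative sketch of a single Adam update for one parameter tensor, written with plain PyTorch tensor operations rather than torch.optim. The function name adam_step and the state dictionary are hypothetical, used only to mirror the equations above.

import torch

def adam_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Lazily initialize the moment estimates and the timestep counter.
    if not state:
        state["m"] = torch.zeros_like(param)
        state["v"] = torch.zeros_like(param)
        state["t"] = 0
    state["t"] += 1
    t = state["t"]

    # First moment: exponentially decaying average of past gradients.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    # Second moment: exponentially decaying average of past squared gradients.
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2

    # Bias correction for the zero initialization of m and v.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    # Update scaled by the per-element effective rate lr / (sqrt(v_hat) + eps).
    return param - lr * m_hat / (v_hat.sqrt() + eps)

# Example: a few updates on a dummy parameter with random "gradients".
param, state = torch.zeros(4), {}
for _ in range(3):
    param = adam_step(param, torch.randn(4), state)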
In PyTorch, using Adam is straightforward:
import torch
import torch.optim as optim
# Assume model is a defined torch.nn.Module
# learning_rate, beta1, beta2, epsilon are hyperparameters
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(beta1, beta2),
    eps=epsilon
)
# Inside the training loop:
# loss = compute_loss(...)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
While Adam works well in many situations, its handling of L2 regularization (weight decay) can be suboptimal. Standard L2 regularization adds a term $\frac{\lambda}{2}\|\theta\|^2$ to the loss function, resulting in a gradient term $\lambda\theta$ being added to $g_t$. In Adam, this weight decay term $\lambda\theta$ becomes part of the adaptive learning rate calculation through $m_t$ and $v_t$. This means the effective weight decay applied to a parameter depends on the historical magnitude of its gradients (via $\sqrt{\hat{v}_t}$). Parameters with large gradients experience smaller effective weight decay than intended, while parameters with small gradients experience larger effective weight decay.
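As a simplified illustration of this coupling (using made-up $\hat{v}_t$ values, and ignoring the mixing of the decay term into $m_t$), the snippet below divides the same decay term $\lambda\theta$ by Adam's adaptive denominator for parameters with small and large gradient histories:

import torch

# The same L2 decay contribution lambda * theta produces very different
# effective steps once divided by sqrt(v_hat) + eps. The v_hat values are
# made up to stand in for small vs. large historical gradient magnitudes.
lam, eps = 0.01, 1e-8
theta = torch.tensor(1.0)
decay_term = lam * theta  # gradient contribution of L2 regularization

for v_hat in (1e-4, 1.0, 1e2):
    effective = decay_term / (torch.tensor(v_hat).sqrt() + eps)
    print(f"v_hat = {v_hat:g} -> effective decay contribution {effective.item():.6f}")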
AdamW proposes a simple fix: decouple the weight decay from the gradient update. Instead of adding $\lambda\theta$ to the gradient $g_t$, AdamW performs the standard Adam update using only the gradient from the primary loss function and then applies the weight decay directly to the parameters after the Adam step.
The AdamW update rule looks like this:
$$\theta_t = \theta_{t-1} - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta\lambda\theta_{t-1}$$
Notice the final term: $-\eta\lambda\theta_{t-1}$. The weight decay is applied directly to the previous weight value $\theta_{t-1}$ and is scaled only by the global learning rate $\eta$, not the adaptive rate. This makes the weight decay behave more like it does in standard SGD with momentum, leading to better generalization performance in many cases, particularly for deep models like Transformers.
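Continuing the earlier sketch, a decoupled update keeps $\lambda\theta$ out of the moment estimates entirely and subtracts it as a separate term. Again, adamw_step is an illustrative name, not the torch.optim implementation:

import torch

def adamw_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    # Lazily initialize moments and the timestep, as in the Adam sketch above.
    if not state:
        state["m"] = torch.zeros_like(param)
        state["v"] = torch.zeros_like(param)
        state["t"] = 0
    state["t"] += 1
    t = state["t"]

    # Moments are computed from the loss gradient only; no lambda * param here.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)

    # Adaptive step from the gradient, plus decay applied directly to the
    # previous weights and scaled only by the global learning rate.
    return param - lr * m_hat / (v_hat.sqrt() + eps) - lr * weight_decay * param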
Using AdamW in PyTorch is similar to Adam, requiring only the additional weight_decay parameter:
import torch
import torch.optim as optim
# Assume model is a defined torch.nn.Module
# learning_rate, beta1, beta2, epsilon, weight_decay_lambda are hyperparameters
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(beta1, beta2),
    eps=epsilon,
    weight_decay=weight_decay_lambda  # Note the weight_decay parameter
)
# Inside the training loop:
# loss = compute_loss(...)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
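One detail worth noting: if weight_decay is omitted, torch.optim.AdamW currently defaults to 0.01, whereas optim.Adam defaults to no weight decay, so passing the value explicitly, as above, keeps the regularization strength visible and easy to tune.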
Due to its improved handling of weight decay and strong empirical performance, AdamW has become a very common choice for training large language models. The choice between Adam and AdamW, along with the settings of their hyperparameters ($\eta$, $\beta_1$, $\beta_2$, $\epsilon$, $\lambda$), often depends on the specific model architecture, dataset, and training setup, and requires careful tuning, as discussed later in this chapter.