While optimizers like SGD with momentum and Adam serve as effective workhorses for many deep learning tasks, achieving optimal performance, especially on complex models or challenging datasets, often benefits from more refined optimization strategies. Standard optimizers sometimes exhibit undesirable behaviors, such as suboptimal handling of weight decay or instability during the initial phases of training. This section introduces several advanced optimizers available in or commonly used with PyTorch, designed to address these specific limitations.
Adam remains a popular and generally effective adaptive learning rate optimizer. However, its standard implementation often conflates L2 regularization with true weight decay.
Recall that L2 regularization adds a penalty term to the loss function based on the squared magnitude of the weights:

$$L_{\text{total}} = L_{\text{original}} + \frac{\lambda}{2}\|w\|^2$$

When computing the gradient, this adds a term proportional to the weight itself ($\lambda w$) to the gradient of the original loss:

$$\nabla_w L_{\text{total}} = \nabla_w L_{\text{original}} + \lambda w$$

In optimizers like Adam, this gradient term $\lambda w$ is then adapted by the optimizer's internal mechanisms (the moving averages of gradients and squared gradients).
Weight decay, as originally proposed, is a different concept. It involves directly subtracting a small fraction of the weight from itself during the update step, after the gradient computation:

$$w_{t+1} = w_t - \eta\left(\nabla_w L_{\text{original}} + \lambda w_t\right)$$

Or, more accurately for adaptive methods, the decay is applied separately from the adapted gradient:

$$w_{t+1} = w_t - \eta \cdot \text{AdaptedGradient}\left(\nabla_w L_{\text{original}}\right) - \eta \lambda' w_t$$

where $\lambda'$ is the weight decay factor.
The key difference is that true weight decay is not scaled by the adaptive learning rates computed by Adam. In the standard Adam implementation that uses L2 regularization, the effective decay applied to weights with historically large gradients can be much smaller than intended, while weights with small gradients might decay too quickly.
AdamW explicitly implements the original concept of weight decay, decoupling it from the gradient adaptation mechanism. This often leads to improved generalization performance compared to Adam with L2 regularization, particularly for models sensitive to regularization strength.
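To make the difference concrete, the following simplified sketch contrasts the two update rules for a single parameter tensor. It ignores Adam's first-moment (momentum) term and uses a random stand-in for the second-moment estimate, so it illustrates the scaling behavior rather than a faithful Adam implementation; all variable names here are illustrative.

import torch

# Toy parameter, loss gradient, and a stand-in for Adam's second-moment estimate
w = torch.randn(5)
grad = torch.randn(5)           # gradient of the original loss w.r.t. w
v_hat = torch.rand(5) + 0.1     # placeholder for Adam's (bias-corrected) second moment
lr, lam, eps = 1e-3, 1e-2, 1e-8

# Adam with L2 regularization: the penalty gradient (lam * w) is added first,
# so it is divided by sqrt(v_hat) along with everything else. Weights with
# large historical gradients (large v_hat) end up decaying less than intended.
w_l2 = w - lr * (grad + lam * w) / (v_hat.sqrt() + eps)

# AdamW-style decoupled decay: only the loss gradient is adaptively scaled;
# the decay term is applied directly to the weights at the full rate.
w_adamw = w - lr * grad / (v_hat.sqrt() + eps) - lr * lam * w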
Using AdamW in PyTorch is straightforward, as it's included in torch.optim:
import torch
from torch import nn
from torch import optim
# A simple example model; substitute your own nn.Module instance
model = nn.Linear(10, 2)

# Example usage of AdamW
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,              # Learning rate
    betas=(0.9, 0.999),   # Coefficients for the moving averages
    eps=1e-8,             # Term added to the denominator for numerical stability
    weight_decay=1e-2,    # Weight decay coefficient (applied in decoupled form)
    amsgrad=False         # Whether to use the AMSGrad variant
)

# Typical training loop step, shown here with dummy data
inputs = torch.randn(16, 10)
targets = torch.randn(16, 2)
loss = nn.functional.mse_loss(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()
For many applications, switching from optim.Adam to optim.AdamW and tuning the weight_decay parameter can provide a noticeable improvement with minimal code change. AdamW is often recommended as a default starting point over standard Adam.
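In practice, the swap is often a single line; the hyperparameter values here are illustrative:

# Before: weight decay implemented as L2 regularization inside Adam
# optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-2)

# After: decoupled weight decay with AdamW
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)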
Lookahead is not an optimizer itself, but rather a mechanism that wraps around an existing optimizer (like AdamW or SGD). It aims to improve learning stability and reduce the variance of the parameter updates by maintaining two sets of weights: "fast" weights and "slow" weights.
The inner, base optimizer (e.g., AdamW) updates the fast weights for k steps. After these k steps, the slow weights are updated by moving them in the direction of the final fast weights from that sequence. The fast weights are then reset to the position of the new slow weights, and the process repeats.
Let $w_{\text{slow}}$ be the slow weights and $w_{\text{fast}}$ be the fast weights, and let $O$ be the inner optimizer (e.g., AdamW). After every $k$ inner steps of $O$, the slow weights are updated by interpolating toward the fast weights:

$$w_{\text{slow}} \leftarrow w_{\text{slow}} + \alpha\,(w_{\text{fast}} - w_{\text{slow}})$$

where $\alpha$ is the interpolation coefficient (la_alpha). The intuition is that the inner optimizer explores the parameter space rapidly with the fast weights, while the slow weights provide a more stable, averaged trajectory, reducing the risk of oscillating around or overshooting optima. Lookahead often leads to faster convergence and better final performance.
Lookahead is not part of the standard torch.optim package, but it can be implemented relatively easily (a minimal sketch appears below) or found in external libraries like torchcontrib, or potentially integrated into frameworks like PyTorch Lightning.
A conceptual implementation might look like this (simplified):
# Conceptual Lookahead usage (requires a Lookahead implementation)
# from some_library import Lookahead # Hypothetical import
# Define base optimizer
base_optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# Wrap with Lookahead
optimizer = Lookahead(base_optimizer, la_steps=5, la_alpha=0.5)
# Training loop remains the same:
# optimizer.zero_grad()
# loss.backward()
# optimizer.step() # Lookahead internally manages fast/slow weights and k steps
Common values for k (la_steps) are 5 or 10, and common values for α (la_alpha) are 0.5 or 0.8.
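For reference, here is a minimal sketch of the wrapping logic. The class name SimpleLookahead and its attribute layout are ours, it reuses the model defined in the earlier example, and production implementations additionally handle things like state_dict saving, closures, and per-group options:

from torch import optim

class SimpleLookahead:
    """Minimal Lookahead sketch: maintains slow weights alongside a base optimizer."""
    def __init__(self, base_optimizer, la_steps=5, la_alpha=0.5):
        self.base = base_optimizer
        self.k = la_steps
        self.alpha = la_alpha
        self.step_count = 0
        # Snapshot of the slow weights, one copy per parameter
        self.slow_weights = [
            [p.detach().clone() for p in group["params"]]
            for group in self.base.param_groups
        ]

    def zero_grad(self):
        self.base.zero_grad()

    def step(self):
        self.base.step()                   # one inner (fast-weight) update
        self.step_count += 1
        if self.step_count % self.k == 0:  # every k steps: sync slow and fast weights
            for group, slow_group in zip(self.base.param_groups, self.slow_weights):
                for p, slow in zip(group["params"], slow_group):
                    slow += self.alpha * (p.detach() - slow)  # slow <- slow + a*(fast - slow)
                    p.data.copy_(slow)                        # reset fast weights to slow

# Usage mirrors the conceptual example above (reuses 'model' from earlier)
base_optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer = SimpleLookahead(base_optimizer, la_steps=5, la_alpha=0.5)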
Adam's adaptive learning rate, based on estimates of the first and second moments of the gradients, is powerful but can suffer from high variance in the early stages of training. When the number of samples (mini-batches) seen is small, the estimate of the second moment ($v_t$) can be unreliable. This can lead to excessively large or small learning rates initially, potentially hindering convergence or causing divergence.
Rectified Adam (RAdam) addresses this issue by introducing a rectification term that adjusts the adaptive learning rate based on the variance of the second moment estimate. Essentially, it measures the variance of the adaptive learning rate term $\hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon)$. If the variance is estimated to be high (typically early in training), RAdam temporarily turns off the adaptive learning rate, effectively operating like SGD with momentum. As more data is processed and the variance estimate becomes more reliable (decreases), the adaptive learning rate mechanism is gradually introduced.
This "warmup" behavior for the adaptive component helps stabilize training from the start, making RAdam less sensitive to the choice of initial learning rate compared to standard Adam.
Unlike Lookahead, RAdam is included in the core torch.optim package in recent PyTorch releases as optim.RAdam; for older versions it is available in several popular extension libraries.
# RAdam usage with the built-in optim.RAdam (recent PyTorch releases);
# older versions can import an equivalent class from an extension library.
optimizer = optim.RAdam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0  # RAdam is typically used without separate weight decay initially
)
# Training loop remains the same:
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
RAdam can be particularly beneficial when training models known to be sensitive to initialization or learning rate choices, or when dealing with smaller batch sizes where variance is inherently higher.
Experimentation remains important. While these optimizers address specific theoretical shortcomings of simpler methods, their practical impact varies depending on the model architecture, dataset, and other hyperparameters. Monitoring training dynamics and validation performance is essential when selecting and tuning these more sophisticated optimizers.