Before we examine the adaptive methods prevalent in modern large language model training, it's helpful to revisit the foundational gradient descent algorithms. Understanding their mechanics and limitations provides context for why more sophisticated optimizers like AdamW are often necessary for navigating the complex optimization landscapes of LLMs.
At its core, optimization in deep learning involves adjusting model parameters (weights and biases, denoted collectively as θ) to minimize a loss function L. Gradient Descent achieves this by iteratively moving the parameters in the direction opposite to the gradient of the loss function with respect to the parameters.
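To make this concrete, the following is a minimal sketch of a single gradient descent step on a toy one-parameter loss, using PyTorch autograd; the toy loss and the names theta and lr are purely illustrative, not part of any model discussed here.

import torch

# Toy objective: minimize L(theta) = (theta - 3)^2, starting from theta = 0.
theta = torch.tensor([0.0], requires_grad=True)  # the parameter to optimize
lr = 0.1                                          # learning rate (eta)

loss = (theta - 3.0) ** 2    # evaluate the loss
loss.backward()              # compute dL/dtheta, here 2 * (theta - 3) = -6

with torch.no_grad():
    theta -= lr * theta.grad   # move opposite the gradient: theta becomes 0.6
    theta.grad.zero_()         # clear the gradient before the next step

Repeating this step moves theta steadily toward the minimum at 3; this is the loop that torch.optim automates at scale.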
Full-batch Gradient Descent calculates the gradient using the entire training dataset, which is computationally infeasible for the massive datasets used in LLM pre-training. Stochastic Gradient Descent (SGD) addresses this by approximating the gradient using only a small, random subset of the data, called a mini-batch, at each step.
The update rule for SGD is:
$$\theta \leftarrow \theta - \eta \, \nabla_{\theta} L\big(\theta;\, x^{(i:i+b)},\, y^{(i:i+b)}\big)$$

Here, θ denotes the model parameters, η is the learning rate, x(i:i+b) and y(i:i+b) are the b input and target examples in the sampled mini-batch, and ∇θL is the gradient of the loss with respect to the parameters, computed on that mini-batch alone.
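Written out by hand for a hypothetical linear least-squares problem (the tensors X, y, w and the batch size b below are made up for illustration), the update looks like this:

import torch

# Hypothetical data and linear model y ≈ X @ w with squared-error loss.
torch.manual_seed(0)
X, y = torch.randn(1000, 5), torch.randn(1000, 1)
w = torch.zeros(5, 1, requires_grad=True)   # the parameters theta
eta, b = 0.01, 32                           # learning rate and mini-batch size

i = torch.randint(0, len(X) - b, (1,)).item()   # random start index of the mini-batch
xb, yb = X[i:i + b], y[i:i + b]                 # x(i:i+b), y(i:i+b)

loss = ((xb @ w - yb) ** 2).mean()   # loss on the mini-batch only
loss.backward()                      # gradient of L with respect to w

with torch.no_grad():
    w -= eta * w.grad    # theta <- theta - eta * gradient
    w.grad.zero_()

In practice the mini-batches usually come from a DataLoader that shuffles the dataset each epoch rather than from sampling a start index, but the parameter update itself is the same.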
The primary advantage of SGD is its computational efficiency per step. Processing a small mini-batch is significantly faster than processing the entire dataset. The stochastic nature of the updates (due to random mini-batch sampling) also introduces noise, which can sometimes help the optimizer escape poor local minima.
However, this noise can also be a disadvantage. The updates can oscillate significantly, leading to a jagged convergence path. Furthermore, SGD can struggle in landscapes with high curvature or ravines (areas where the surface curves much more steeply in one dimension than another), potentially taking many steps to navigate towards the minimum. SGD is also quite sensitive to the choice of learning rate η.
In PyTorch, using SGD is straightforward:
import torch
# Assume 'model' is your defined neural network
# Assume 'loss_fn' is your loss function
# Assume 'data_loader' provides mini-batches of data
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Inside your training loop:
for inputs, targets in data_loader:
    optimizer.zero_grad()              # Reset gradients from previous step
    outputs = model(inputs)            # Forward pass
    loss = loss_fn(outputs, targets)   # Calculate loss
    loss.backward()                    # Backward pass (compute gradients)
    optimizer.step()                   # Update parameters
To mitigate the oscillations inherent in SGD and accelerate convergence, particularly in ravines, the Momentum technique was introduced. It adds a "velocity" term, v, which accumulates an exponentially decaying moving average of past gradients. The parameter update then incorporates this velocity.
The update rules are typically formulated as:

$$v_t = \beta\, v_{t-1} + \nabla_{\theta} L(\theta_t)$$
$$\theta_{t+1} = \theta_t - \eta\, v_t$$

Here, v_t is the velocity at step t, β is the momentum coefficient (commonly set to 0.9) that controls how quickly the influence of past gradients decays, η is the learning rate, and ∇θL(θ_t) is the gradient computed on the current mini-batch. This formulation matches the convention used by PyTorch's SGD optimizer.
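To see how the velocity builds up, here is a small hand-written sketch of these two updates on the same toy quadratic as before; the names v, beta, and eta mirror the symbols above and are purely illustrative.

import torch

theta = torch.tensor([0.0], requires_grad=True)  # parameter
v = torch.zeros_like(theta)                      # velocity, starts at zero
eta, beta = 0.1, 0.9                             # learning rate and momentum coefficient

for step in range(3):
    loss = (theta - 3.0) ** 2          # toy quadratic loss
    loss.backward()
    with torch.no_grad():
        v = beta * v + theta.grad      # accumulate a decaying sum of past gradients
        theta -= eta * v               # step along the smoothed direction
        theta.grad.zero_()

Because successive gradients in this toy example all point toward the minimum, the velocity grows and each step is larger than the last, which is exactly the acceleration effect described above.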
The velocity term v_t helps smooth out the updates. If successive gradients point in similar directions, the velocity builds up, leading to larger steps and faster convergence. If gradients oscillate, the momentum term helps dampen these oscillations by averaging them out. Think of it like a heavy ball rolling down a hill; it maintains momentum in its current direction and is less affected by small bumps (noisy gradients).
While Momentum generally improves convergence speed and stability compared to plain SGD, it still relies on a single learning rate η for all parameters and requires careful tuning of both η and β.
Using Momentum in PyTorch only requires adding the momentum argument to the SGD optimizer:
import torch
# Assume 'model' is your defined neural network
learning_rate = 0.01
momentum_beta = 0.9
# Use SGD optimizer with momentum
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum_beta
)
# Training loop remains the same as the SGD example...
# Inside your training loop:
# for inputs, targets in data_loader:
#     optimizer.zero_grad()
#     outputs = model(inputs)
#     loss = loss_fn(outputs, targets)
#     loss.backward()
#     optimizer.step()
While SGD and Momentum form the basis of many optimization strategies, training large language models often involves navigating extremely high-dimensional parameter spaces with complex loss surfaces. These foundational methods can be slow to converge or get stuck. This motivates the use of adaptive optimization algorithms like Adam and AdamW, which adjust the learning rate per parameter and often lead to faster convergence in practice for these large-scale models. We will explore these adaptive methods in the following sections.