When training large language models, instabilities like sudden loss spikes or divergence (NaN values) are unfortunately common. As discussed earlier in the course, certain techniques are fundamental not just for effective training, but specifically for preventing and mitigating these issues. Let's revisit three significant stabilization methods: gradient clipping, learning rate scheduling, and learning rate warmup, viewing them through the lens of troubleshooting instability. Proper configuration and monitoring of these components are often the first line of defense when a training run starts to behave unexpectedly.
Gradient clipping directly addresses the problem of exploding gradients, a frequent cause of NaN losses or sudden divergence. By imposing a maximum limit on the magnitude of the gradients, it prevents excessively large updates to the model's weights, which can destabilize the training process, particularly in the deep stacks of layers found in Transformers.
The most common technique is clipping the norm of the gradients across all model parameters. If the L2 norm (Euclidean norm) of the gradients for the entire model (or sometimes per parameter group) exceeds a predefined threshold max_norm, the gradients are scaled down proportionally to match this threshold:

$$g \leftarrow g \cdot \min\left(1, \frac{\text{max\_norm}}{\lVert g \rVert_2}\right)$$

Here, g represents the vector of all gradients. For example, with max_norm = 1.0 and a gradient norm of 4.0, every gradient is multiplied by 0.25 so that the clipped norm equals 1.0.
In PyTorch, this is typically applied after the backward pass but before the optimizer step:
import torch
from torch.nn.utils import clip_grad_norm_

# Assume model, optimizer, compute_loss, outputs, targets are defined
# ... forward pass ...
loss = compute_loss(outputs, targets)

# ... backward pass ...
loss.backward()

# Clip gradients in place; the call returns the total norm computed
# before clipping, which is useful to log for monitoring
max_norm = 1.0  # A common starting value, requires tuning
total_norm = clip_grad_norm_(model.parameters(), max_norm)

# Optimizer step
optimizer.step()
optimizer.zero_grad()
Choosing max_norm: The value for max_norm is a hyperparameter. Common values range from 0.5 to 10.0, with 1.0 being a frequent default. Setting it too low can hinder learning by excessively shrinking updates, while setting it too high makes it ineffective against large spikes. Monitoring the gradient norm during stable runs (as discussed in the "Monitoring Training Metrics" section) can provide a baseline for choosing an appropriate value. If you observe norms frequently exceeding your chosen max_norm during stable periods, you might be clipping too aggressively. Conversely, if loss spikes occur and the norm at those points is far above your max_norm, clipping is likely helping.
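As a rough baseline-gathering aid, here is a minimal sketch of how you might record gradient norms over a known-stable run to inform the choice of max_norm. The helper name record_grad_norm and the plain Python list are illustrative assumptions, not part of any particular library:

import torch

# Collected once per step during a known-stable run
grad_norm_history = []

def record_grad_norm(model):
    # Compute the global L2 norm over all parameter gradients,
    # mirroring what clip_grad_norm_ measures before clipping
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_sq += p.grad.detach().float().norm(2).item() ** 2
    grad_norm_history.append(total_sq ** 0.5)

# After the run, inspect the distribution, for example a high percentile,
# and set max_norm somewhat above typical stable values:
# norms = torch.tensor(grad_norm_history)
# print(torch.quantile(norms, 0.95))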
While effective, gradient clipping should not be seen as a cure-all; it is a stabilization mechanism. If clipping is constantly active or required at very low thresholds, it might indicate deeper issues with the model architecture, initialization, learning rate, or data quality that should also be investigated.
A fixed learning rate is rarely optimal for training large, complex models. Learning rate schedules dynamically adjust the learning rate during training, balancing exploration early on with finer convergence later. This controlled adjustment is significant for stability. A learning rate that is too high throughout training can easily lead to oscillations or divergence, while one that decays too slowly might keep making large, potentially destabilizing steps even when nearing a minimum.
Common schedules used for LLMs, often applied after an initial warmup phase, include:

- Cosine decay: the learning rate follows a half cosine curve from the base value down to a small (often zero) final value; this is the most widely used choice for LLM pretraining.
- Linear decay: the learning rate decreases linearly from the base value to a final value over the remaining steps.
- Inverse square root decay: the learning rate decays proportionally to the inverse square root of the step count, popularized by the original Transformer training recipe.
Here's how you might combine warmup and cosine decay using PyTorch's schedulers:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingLR, LinearLR, SequentialLR
)

# Assume model is defined
optimizer = AdamW(model.parameters(), lr=1e-4)  # Base learning rate

# Define warmup and decay phases
warmup_steps = 500
total_training_steps = 10000
decay_steps = total_training_steps - warmup_steps

# Warmup scheduler (linear increase from near zero to base LR)
# Note: LinearLR factor starts at 1/warmup_steps and rises to 1.0
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=1.0 / warmup_steps,
    end_factor=1.0,
    total_iters=warmup_steps
)

# Cosine decay scheduler (from base LR to 0)
# T_max is the number of steps for one half cycle
decay_scheduler = CosineAnnealingLR(
    optimizer, T_max=decay_steps, eta_min=0
)

# Chain the schedulers: warmup first, then cosine decay
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, decay_scheduler],
    milestones=[warmup_steps]
)

# Inside training loop:
# ... training steps ...
optimizer.step()
scheduler.step()  # Update learning rate once per step
optimizer.zero_grad()
An improperly configured schedule can contribute to instability. For example, if the decay is too slow or the final learning rate (eta_min in CosineAnnealingLR) is too high, the optimizer might overshoot or oscillate near the end of training. Conversely, decaying too quickly can stall convergence. When faced with late-stage instability, reviewing the learning rate schedule's behavior is a worthwhile diagnostic step.
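One practical way to review that behavior is to step the schedule on a throwaway optimizer, without doing any training, and inspect the resulting values. This is only a sketch reusing the warmup and decay settings from the example above; the dummy parameter exists purely so the optimizer has something to hold:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingLR, LinearLR, SequentialLR
)

# Dummy parameter and optimizer purely for stepping the schedule
dummy = torch.nn.Parameter(torch.zeros(1))
opt = AdamW([dummy], lr=1e-4)

warmup_steps, total_steps = 500, 10000
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=1.0 / warmup_steps,
                 end_factor=1.0, total_iters=warmup_steps),
        CosineAnnealingLR(opt, T_max=total_steps - warmup_steps,
                          eta_min=0),
    ],
    milestones=[warmup_steps],
)

# Record the LR at every step without doing any training
lrs = []
for _ in range(total_steps):
    lrs.append(sched.get_last_lr()[0])
    opt.step()    # no-op update on the dummy parameter
    sched.step()

# Spot-check a few points: end of warmup, middle, and final LR
print(lrs[warmup_steps - 1], lrs[total_steps // 2], lrs[-1])

Plotting or spot-checking the recorded values makes it easy to see whether the peak, decay shape, and final learning rate match what you intended before committing compute to a long run.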
The very beginning of training is a particularly vulnerable phase. Model weights are typically randomly initialized, and applying the full target learning rate immediately can cause extremely large, chaotic updates, leading to loss explosions or NaN values within the first few iterations. Adaptive optimizers like AdamW also need some initial steps to build up reliable estimates of gradient moments.
Learning rate warmup addresses this by starting with a very small learning rate (often zero or close to it) and gradually increasing it to the target base learning rate over a predefined number of initial training steps (the warmup_steps). This gives the model and optimizer time to stabilize before larger updates are applied. Linear warmup is the most common strategy.
A typical learning rate schedule including 500 steps of linear warmup followed by cosine decay to zero over the remaining steps.
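For reference, the same warmup-plus-cosine shape can also be written out explicitly with a single LambdaLR. This is a minimal sketch of that alternative, assuming the same model, base learning rate, and step counts as above; it is not the implementation used elsewhere in the course:

import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

warmup_steps = 500
total_training_steps = 10000

def lr_factor(step):
    # Linear warmup: multiplier rises from 0 to 1 over warmup_steps
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Cosine decay: multiplier falls from 1 to 0 over the remaining steps
    progress = (step - warmup_steps) / max(
        1, total_training_steps - warmup_steps
    )
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Assume model is defined
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=lr_factor)

Writing the multiplier as a function makes the warmup fraction step / warmup_steps explicit, which can make it easier to reason about the size of the very first updates.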
Choosing warmup_steps: The number of warmup steps is another hyperparameter. It often ranges from hundreds to several thousand steps, depending on the overall training duration, batch size, and dataset. A common practice is to set it to a small percentage (e.g., 1-10%) of the total expected training steps. Too few warmup steps might not prevent initial instability, while too many can slow down the start of effective learning. If you experience instability very early in training, increasing the warmup duration or decreasing the initial learning rate are primary adjustments to consider.
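As a purely illustrative calculation of that rule of thumb (the 3% fraction and the 100-step floor below are assumptions, not recommendations from the text):

total_training_steps = 100_000
warmup_fraction = 0.03  # assumed example within the 1-10% range
warmup_steps = max(100, int(warmup_fraction * total_training_steps))
print(warmup_steps)  # 3000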
These three techniques are interconnected. For instance:

- Warmup lets you reach a higher peak learning rate safely; the same base rate applied from step 0 might diverge immediately.
- Gradient clipping acts as a safety net while the schedule holds the learning rate near its peak, when large updates are most likely.
- If clipping fires constantly, the schedule (peak learning rate or decay shape) may be too aggressive, so the two should be tuned together.
When diagnosing instabilities, it's important to monitor the behavior related to these techniques:

- The current learning rate at every step, to confirm the warmup and decay are actually being applied as configured.
- The gradient norm before clipping, and how often it exceeds max_norm.
- The loss curve aligned with these signals, so you can see whether spikes coincide with schedule transitions or unclipped gradient surges.

A lightweight way to capture these signals together is sketched below.
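This is a minimal sketch, assuming model, optimizer, scheduler, compute_loss, and a dataloader yielding (inputs, targets) pairs as in the earlier examples; replace the print with your own logger or experiment tracker:

from torch.nn.utils import clip_grad_norm_

max_norm = 1.0
for step, (inputs, targets) in enumerate(dataloader):
    loss = compute_loss(model(inputs), targets)
    loss.backward()

    # Pre-clipping global gradient norm (returned by the clipping call)
    grad_norm = clip_grad_norm_(model.parameters(), max_norm).item()

    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # Log the three signals side by side
    current_lr = scheduler.get_last_lr()[0]
    print(f"step={step} loss={loss.item():.4f} "
          f"lr={current_lr:.2e} grad_norm={grad_norm:.2f}")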
Revisiting and carefully tuning gradient clipping, learning rate schedules, and warmup phases are essential steps when troubleshooting training instabilities. They provide knobs to control the update dynamics, preventing the explosions and divergences that can halt progress in large-scale language model training.