Recognizing the early signs of training instability is fundamental for preventing wasted compute cycles and ensuring the successful development of large language models. While robust monitoring is discussed in the next section, understanding the common symptoms allows you to quickly identify when a training run is deviating from a healthy trajectory. These symptoms often manifest in the primary metrics you track during training, particularly the loss function and gradient statistics.
Perhaps the most definitive sign of catastrophic failure is the appearance of NaN (Not a Number) values in your loss calculation. A NaN loss typically halts the training process immediately, as subsequent gradient calculations and weight updates become mathematically undefined. It usually signifies a severe numerical issue, such as taking the logarithm of zero (log(0)), calculating the square root of a negative number, or dividing by zero. These operations can be triggered by specific data points or by unstable intermediate activation values. Detecting NaNs early in the training loop is important, and you can add a check directly after the loss computation.
# Inside the training loop (PyTorch example)
outputs = model(inputs)
loss = loss_function(outputs, targets)

# Check for NaN or infinity in the loss
if not torch.isfinite(loss):
    print(f"Unstable loss detected: {loss.item()}. Stopping training.")
    # Add logic here to save state, log details, and terminate
    break

# Proceed with backward pass only if loss is valid
loss.backward()
# ... optimizer step, etc.
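If a NaN does appear and you need to trace which operation produced it, PyTorch's autograd anomaly detection can help while debugging. This is an optional diagnostic rather than part of the standard loop, and it adds significant overhead, so enable it only while investigating:

# Optional diagnostic: makes the backward pass raise an error, with a traceback
# of the forward operation, when a NaN is produced. Slows training considerably.
torch.autograd.set_detect_anomaly(True)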
Less immediately fatal than a NaN loss, but still a serious warning sign, are sudden, sharp increases in the loss value, often referred to as "loss spikes." The loss might recover partially or fully in subsequent steps, or the spike might precede complete divergence.
A typical loss curve showing a sudden spike around iteration 500 before partially recovering.
Loss spikes can be triggered by several factors. For example, similar to the numerical issues behind NaN losses, gradients might momentarily explode, causing large, destabilizing weight updates; this is particularly relevant when using adaptive optimizers like Adam.

While a single, isolated spike might not derail training entirely, frequent spikes indicate underlying instability that needs addressing.
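Because spikes can pass quickly, it helps to flag them automatically rather than rely on scanning dashboards. The sketch below is one simple, illustrative approach, not a prescribed recipe: it compares each new loss to the mean of a recent window and flags values that exceed a chosen factor. The window size and spike_factor threshold are placeholder values you would tune for a real run.

from collections import deque

# Illustrative spike detector: flag a loss that exceeds the recent average
# by a chosen factor. Window size and factor are placeholders, not tuned values.
recent_losses = deque(maxlen=100)
spike_factor = 3.0

def is_loss_spike(loss_value):
    """Return True if loss_value is far above the recent average."""
    if len(recent_losses) < 10:
        recent_losses.append(loss_value)
        return False  # Not enough history to judge yet
    average = sum(recent_losses) / len(recent_losses)
    recent_losses.append(loss_value)
    return loss_value > spike_factor * average

Calling is_loss_spike(loss.item()) after each step lets you log the offending batch or save a checkpoint for later inspection when a spike is detected.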
Unlike a temporary spike, diverging loss refers to a consistent upward trend in the loss value over multiple iterations or epochs. This indicates that the model's performance is continuously degrading, and the optimization process is moving away from, rather than towards, a good solution.
Comparison of a healthy, converging loss curve and a diverging loss curve indicating training failure.
Divergence often points to more fundamental issues with the training setup than a transient spike does.
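Divergence is easiest to confirm numerically by checking for a sustained upward trend rather than reacting to a single bad step. As a rough sketch (the window size and the use of numpy.polyfit for a linear fit are illustrative choices, not a standard method), you can fit a line to the most recent losses and alert when the slope stays positive:

import numpy as np

def loss_trend(losses, window=200):
    """Return the slope of a linear fit over the last `window` loss values.

    A persistently positive slope suggests the loss is diverging rather
    than experiencing a one-off spike. The window size is illustrative.
    """
    recent = np.asarray(losses[-window:], dtype=np.float64)
    if recent.size < 2:
        return 0.0
    steps = np.arange(recent.size)
    slope, _intercept = np.polyfit(steps, recent, deg=1)
    return slope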
Another symptom is when the loss value, or other evaluation metrics like perplexity or accuracy on a validation set, fluctuates significantly without showing a clear trend of improvement. The values might bounce up and down between steps or epochs, never settling into a stable decrease (for loss) or increase (for accuracy).
This oscillation often suggests that the optimization process is unable to settle into consistent progress.
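A complementary check for oscillation is to measure how much recent losses vary relative to their mean. The helper below is a hedged sketch; the window and the relative-variation threshold are arbitrary placeholders you would calibrate against a known-good run.

import statistics

def is_oscillating(losses, window=200, max_relative_std=0.2):
    """Return True if recent losses fluctuate widely around their mean."""
    recent = losses[-window:]
    if len(recent) < 2:
        return False
    mean = statistics.fmean(recent)
    if mean == 0:
        return False
    relative_std = statistics.stdev(recent) / abs(mean)
    return relative_std > max_relative_std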
Extreme gradient magnitudes are not always directly visible in the way a NaN loss or a spike is, but they are often the root cause of those symptoms: very large gradient values can lead to NaNs or loss spikes through oversized weight updates. Monitoring the gradient norm is essential for detecting this.
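A common way to track this is to compute the global L2 norm of all parameter gradients after each backward pass. The snippet below is a sketch that reuses the model from the earlier example; the `step` counter and the logging interval are assumptions for illustration.

# After loss.backward(): compute the global L2 norm over all parameter gradients
grad_norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack(grad_norms), 2)

# Log periodically; a sudden jump in this value often precedes a loss spike or NaN.
if step % 50 == 0:  # `step` is assumed to be the training loop counter
    print(f"step {step}: gradient norm = {total_norm.item():.3f}")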
Recognizing these common symptoms is an important first step. The following sections provide guidance on how to monitor training metrics effectively to catch these signs early and diagnose the underlying causes of instability.