Recognizing the early signs of training instability is fundamental for preventing wasted compute cycles and ensuring the successful development of large language models. Understanding the common symptoms allows you to quickly identify when a training run is deviating from a healthy trajectory. These symptoms often manifest in the primary metrics tracked during training, particularly the loss function and gradient statistics, and identifying them is a primary part of effective training monitoring.
Perhaps the most definitive sign of catastrophic failure is the appearance of NaN (Not a Number) values in your loss calculation.
A NaN loss typically halts the training process immediately, as subsequent gradient calculations and weight updates become mathematically undefined. It usually indicates a severe numerical issue, such as taking the logarithm of zero (log(0)), calculating the square root of a negative number, or dividing by zero. These can occur due to specific data points or unstable intermediate activation values.
Detecting NaNs early in the training loop is important. You can add checks directly after the loss computation.
# Inside the training loop (PyTorch example)
outputs = model(inputs)
loss = loss_function(outputs, targets)

# Check for NaN or infinity in the loss
if not torch.isfinite(loss):
    print(f"Unstable loss detected: {loss.item()}. Stopping training.")
    # Add logic here to save state, log details, and terminate
    break

# Proceed with backward pass only if loss is valid
loss.backward()
# ... optimizer step, etc.
Less immediately fatal than NaN loss, but still a serious warning sign, are sudden, sharp increases in the loss value, often referred to as "loss spikes." The loss might recover partially or fully in subsequent steps, or the spike might precede complete divergence.
A typical loss curve showing a sudden spike around iteration 500 before partially recovering.
Loss spikes can be triggered by several factors, such as a batch of unusual or low-quality data, a learning rate that is momentarily too aggressive, or transient numerical issues. Even when gradients do not reach NaN, they might momentarily explode, causing large, destabilizing weight updates; this is particularly relevant when using adaptive optimizers like Adam, whose moment estimates can be skewed by a single extreme gradient. While a single, isolated spike might not derail training entirely, frequent spikes indicate underlying instability that needs addressing.
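A lightweight way to catch spikes automatically, rather than only spotting them on a dashboard, is to compare each new loss value against a running average of recent steps. The sketch below is illustrative: the SpikeDetector class, its window size, and the spike_factor threshold are assumptions for demonstration, not values prescribed by any particular framework.

# Running-average spike detector (illustrative thresholds)
from collections import deque

class SpikeDetector:
    def __init__(self, window=100, spike_factor=2.0):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor

    def update(self, loss_value):
        # Only flag spikes once the window is full, to avoid noisy early steps
        is_spike = (
            len(self.history) == self.history.maxlen
            and loss_value > self.spike_factor * (sum(self.history) / len(self.history))
        )
        self.history.append(loss_value)
        return is_spike

# In the training loop:
# if detector.update(loss.item()):
#     print(f"Loss spike at step {step}: {loss.item():.4f}")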
Unlike a temporary spike, diverging loss refers to a consistent upward trend in the loss value over multiple iterations or epochs. This indicates that the model's performance is continuously degrading, and the optimization process is moving away from, rather than towards, a good solution.
Comparison of a healthy, converging loss curve and a diverging loss curve indicating training failure.
Divergence often points to more fundamental issues than a single bad batch, most commonly a learning rate that is too high for the model and data, but also problems such as a poorly conditioned initialization or a misconfigured optimizer.
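To distinguish genuine divergence from an isolated spike programmatically, one option is to compare the mean loss over the most recent window of steps against the mean over the window before it, and flag a sustained increase. This is a sketch; the is_diverging helper, its window size, and the tolerance are illustrative assumptions you would tune for your own run.

# Compare recent mean loss against the preceding window (illustrative values)
def is_diverging(loss_history, window=200, tolerance=1.05):
    if len(loss_history) < 2 * window:
        return False  # Not enough history to judge a trend yet
    recent_mean = sum(loss_history[-window:]) / window
    previous_mean = sum(loss_history[-2 * window:-window]) / window
    # A sustained rise beyond the tolerance suggests divergence, not a one-off spike
    return recent_mean > tolerance * previous_mean

# In the training loop, append loss.item() to loss_history each step
# and call is_diverging(loss_history) every few hundred steps.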
Another symptom is when the loss value, or other evaluation metrics like perplexity or accuracy on a validation set, fluctuates significantly without showing a clear trend of improvement. The values might bounce up and down between steps or epochs, never settling into a stable decrease (for loss) or increase (for accuracy).
This oscillation often suggests a learning rate that is too high to settle into a minimum, an effective batch size that is too small (producing noisy gradient estimates), or inconsistent, poorly shuffled data.
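One way to quantify oscillation, instead of judging the curve by eye, is to measure the spread of recent loss values relative to their mean and check whether there has been any net improvement. The heuristic below is a rough sketch; the is_oscillating helper, its window, and the threshold are illustrative assumptions.

import statistics

# Rough oscillation heuristic: high relative spread with no net progress
def is_oscillating(loss_history, window=200, cv_threshold=0.2):
    if len(loss_history) < window:
        return False
    recent = loss_history[-window:]
    spread = statistics.pstdev(recent) / statistics.fmean(recent)  # coefficient of variation
    no_net_progress = recent[-1] >= recent[0]
    return spread > cv_threshold and no_net_progress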
While not always directly visible in the same way as primary symptoms like NaN loss or spikes, extreme gradient magnitudes are often the root cause. Exploding gradients produce very large, destabilizing weight updates that can push parameters into unstable regions and eventually surface as NaNs or loss spikes. Monitoring the gradient norm is essential for detecting this; a minimal example follows.
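The sketch below reuses the model and optimizer assumed in the earlier loop. It relies on torch.nn.utils.clip_grad_norm_, which returns the total norm computed before clipping; the max_norm value, the warning threshold, and the skip-update policy are illustrative assumptions rather than recommended settings.

# After loss.backward(): measure (and optionally clip) the total gradient norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

if not torch.isfinite(total_norm):
    print("Non-finite gradient norm detected. Skipping this update.")
    optimizer.zero_grad()
elif total_norm > 10.0:  # Illustrative warning threshold
    print(f"Unusually large gradient norm: {total_norm.item():.2f}")

# Log total_norm.item() alongside the loss so trends are visible over time.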
Recognizing these common symptoms is the important first step. The following sections provide guidance on how to monitor training metrics effectively to catch these signs early and diagnose the underlying causes of instability.