Loss spikes are among the most frustrating occurrences during large model training. A run progressing smoothly for days or weeks can suddenly exhibit a sharp, often vertical, increase in the loss value, sometimes collapsing into NaN (Not a Number) or inf (infinity). This typically halts effective training. Diagnosing the root cause requires a systematic investigation, as several factors can trigger such events.
When confronted with a loss spike, the immediate goal is to understand when and why it happened. Monitoring tools (like TensorBoard, Weights & Biases) are indispensable here, allowing you to pinpoint the exact training step where the spike occurred.
The chart above shows a typical loss spike where the loss value abruptly increases before potentially recovering or diverging completely. Note the logarithmic scale often used for loss visualization.
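As a minimal illustration, the sketch below logs the loss at every step with TensorBoard's SummaryWriter so the spike step can be located precisely later. It assumes a standard PyTorch loop in which a scalar loss tensor and a global_step counter already exist; adapt it to your own logging setup.

# Example: Logging the loss every step so a spike can be located precisely
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/loss-spike-debug")  # hypothetical log directory

# --- Inside your training loop ---
# writer.add_scalar("train/loss", loss.item(), global_step)
# if not torch.isfinite(loss):
#     print(f"Non-finite loss at step {global_step}; save this batch for inspection.")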
Here’s a breakdown of common causes and how to investigate them:
Often, the trigger is a single "bad" batch of data.

Diagnostic Steps:
- Check for NaN values in the source data (if applicable), or for characters outside the tokenizer's vocabulary that might be handled poorly; a token-ID range check is sketched after the snippet below.

# Example: Checking a batch tensor for NaNs or Infs in PyTorch
# Assume `input_ids` is the tensor for the problematic batch
import torch

def check_tensor_health(tensor: torch.Tensor, name: str):
    has_nan = torch.isnan(tensor).any()
    has_inf = torch.isinf(tensor).any()
    if has_nan or has_inf:
        print(f"Problem detected in tensor '{name}':")
        if has_nan:
            print(" - Contains NaN values!")
            print(f" - NaN count: {torch.isnan(tensor).sum().item()}")
        if has_inf:
            print(" - Contains Inf values!")
            print(f" - Inf count: {torch.isinf(tensor).sum().item()}")
        # Consider logging or printing problematic parts of the tensor
        # print(tensor[torch.isnan(tensor) | torch.isinf(tensor)])
        return False
    return True

# --- Inside your training loop or debugging script ---
# Load or identify the problematic batch data (e.g., input_ids, attention_mask)
# input_ids = load_problematic_batch(...)

# Check input tensors
# if not check_tensor_health(input_ids, "input_ids"):
#     # Handle error or set a breakpoint
#     pass

# After the forward pass, check model outputs
# model_output = model(input_ids)
# loss = calculate_loss(model_output, labels)
# if not check_tensor_health(loss, "Calculated Loss"):
#     # Investigate why the loss became NaN/Inf
#     pass
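The same batch can also be screened for token IDs that fall outside the expected vocabulary range, which often indicates a tokenization problem. The helper below is a sketch; vocab_size is an assumed value you would take from your model or tokenizer configuration.

# Example: Checking that token IDs fall inside the expected vocabulary range
# `vocab_size` is assumed to be known, e.g. from model.config.vocab_size
import torch

def check_token_ids(input_ids: torch.Tensor, vocab_size: int) -> bool:
    min_id = int(input_ids.min())
    max_id = int(input_ids.max())
    if min_id < 0 or max_id >= vocab_size:
        print(
            f"Token IDs out of range: min={min_id}, max={max_id}, "
            f"vocab_size={vocab_size}"
        )
        return False
    return True

# if not check_token_ids(input_ids, vocab_size):
#     # Inspect how this batch was tokenized
#     pass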
Mechanistically, loss spikes are often caused by exploding gradients. Even if the input data seems fine, operations within the model can produce excessively large values.
Diagnostic Steps:
- Check for NaN/inf: Before the optimizer step (optimizer.step()), check the loss tensor itself and the gradients of the model parameters for NaN or inf values. A NaN loss is a definitive sign of numerical instability upstream. NaN gradients will corrupt the weights upon the optimizer step.
- Monitor the gradient norm: torch.nn.utils.clip_grad_norm_ not only clips but also returns the total norm before clipping, which is useful for logging.

# Example: Checking gradients before optimizer step in PyTorch
# --- Inside your training loop, after loss.backward() ---
# Assumes `global_step` and `gradient_clipping_threshold` are defined elsewhere
total_norm = 0.0
nan_or_inf_found = False
for name, p in model.named_parameters():
    if p.grad is not None:
        if not check_tensor_health(p.grad, f"Gradient of {name}"):
            nan_or_inf_found = True
            # Optionally break or log more details about the specific parameter
            # break
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Step {global_step}: Total Gradient Norm: {total_norm}")

if nan_or_inf_found:
    print(
        f"Step {global_step}: NaN or Inf detected in gradients "
        f"BEFORE optimizer step. Skipping update."
    )
    # Potentially skip optimizer.step() for this batch or halt training
    # optimizer.zero_grad()  # Still need to clear gradients
    # continue, or raise an exception
elif total_norm > gradient_clipping_threshold * 10:  # arbitrary large multiplier
    print(
        f"Warning: Step {global_step}: Gradient norm ({total_norm}) "
        f"significantly exceeds clipping threshold "
        f"({gradient_clipping_threshold}). Potential instability."
    )

# Optional: gradient clipping (often applied even without spikes, but important here)
# torch.nn.utils.clip_grad_norm_(
#     model.parameters(), max_norm=gradient_clipping_threshold
# )

# If no NaN/Inf gradients were found:
# optimizer.step()
# optimizer.zero_grad()
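If these checks show that gradients are already NaN once loss.backward() finishes, PyTorch's autograd anomaly detection can help locate the operation where the NaN first appears. It adds significant overhead, so enable it only while reproducing the problematic step; a rough sketch, reusing the placeholder names from the earlier snippets:

# Example: Temporarily enabling autograd anomaly detection (debug only, slow)
with torch.autograd.detect_anomaly():
    model_output = model(input_ids)
    loss = calculate_loss(model_output, labels)  # placeholder loss computation
    loss.backward()
# The raised error's traceback points to the forward op whose backward produced NaN.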
The learning rate can also play a role, although it is less likely to cause a single abrupt spike after stable training; an overly high learning rate more often causes divergence over several steps.
Diagnostic Steps:
- Inspect the optimizer state (e.g., momentum or variance estimates that have become NaN/inf). Frameworks usually handle this, but inspecting optimizer.state_dict() might reveal issues in very obscure cases; see the sketch below.
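Since the learning rate itself can be a factor, it is also worth confirming what value the optimizer is actually using at the spike step alongside the state check. The sketch below assumes the check_tensor_health helper from earlier and a standard PyTorch optimizer; the exact state keys depend on the optimizer you use.

# Example: Logging the applied learning rate and scanning the optimizer state
current_lrs = [group["lr"] for group in optimizer.param_groups]
print(f"Step {global_step}: learning rate(s) in use: {current_lrs}")

for param, state in optimizer.state.items():
    for key, value in state.items():
        if torch.is_tensor(value):
            # e.g. 'exp_avg' and 'exp_avg_sq' for Adam-style optimizers
            if not check_tensor_health(value, f"optimizer state '{key}'"):
                print(f"Step {global_step}: corrupted optimizer state in '{key}'.")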
Training with FP16 (16-bit floating-point) is particularly prone to numerical range issues. While BF16 (bfloat16) offers a wider range, extreme values can still cause problems.
Diagnostic Steps:
- If using FP16 with automatic mixed precision (AMP), ensure loss scaling is active. Loss spikes can occur if scaled gradients exceed 65504 (the maximum representable FP16 value) before the loss scaler unscales them. Check whether the loss scale value itself became NaN or zero, which can happen if gradients underflowed repeatedly before an overflow (a monitoring sketch follows this list).
- Confirm which operations are kept in FP32. Check whether any custom operations or numerically unstable functions (like certain reductions or normalizations on extreme values) are being performed in lower precision.
- BF16 is generally more stable than FP16 due to its wider dynamic range, often eliminating the need for loss scaling. Experiencing spikes with FP16 might motivate switching to BF16.
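For FP16 runs using torch.cuda.amp, the GradScaler exposes its current scale, which makes the first check above easy to automate. The sketch below assumes a standard scaler-based loop where scaler.scale(loss).backward() has already been called; it is illustrative rather than a complete recipe.

# Example: Monitoring the AMP loss scale around the optimizer step
import math

scale_before = scaler.get_scale()
scaler.step(optimizer)   # internally skips the step if inf/NaN gradients were found
scaler.update()
scale_after = scaler.get_scale()

if scale_after < scale_before:
    # The scaler backs off the scale when it sees inf/NaN gradients on this step
    print(
        f"Step {global_step}: overflow detected, loss scale reduced "
        f"from {scale_before} to {scale_after}"
    )
if scale_after == 0 or not math.isfinite(scale_after):
    print(f"Step {global_step}: loss scale is degenerate ({scale_after}).")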
Diagnosing loss spikes is an iterative process. By systematically checking the data, monitoring gradients and activations, reviewing the optimizer configuration, and considering the specifics of mixed-precision training, you can usually pinpoint the source of the instability and apply the appropriate mitigation techniques discussed elsewhere in this chapter, such as adjusting the learning rate, improving data cleaning, or refining gradient clipping and loss scaling strategies.