Mixed-precision training, leveraging lower-precision formats like FP16 (half-precision) or BF16 (bfloat16), is a standard technique for accelerating LLM training and reducing memory footprints. However, these formats trade numerical headroom for speed: FP16 in particular has a significantly narrower representable range than the default FP32 (single-precision), and both carry fewer bits of precision. This can lead to numerical instability, manifesting as NaN (Not a Number) or Inf (Infinity) values in activations, gradients, or the loss itself, derailing the training process. Debugging these issues requires careful inspection and an understanding of where precision limitations might strike.
Recall from Chapter 20 that FP16 has a limited dynamic range. Very small numbers (like small gradients) can become zero (underflow), effectively stopping learning for those parameters. Conversely, large numbers (activations or gradients) can exceed the maximum representable value, becoming Inf (overflow). NaN values typically arise from mathematically undefined operations like 0/0, √−1, or ∞−∞, which can occur if intermediate computations involve Inf.
BF16 was designed with deep learning in mind, offering the same dynamic range as FP32 but with reduced precision (fewer mantissa bits). This significantly mitigates the overflow problem compared to FP16, often eliminating the need for techniques like loss scaling. However, its lower precision can still sometimes cause issues in operations sensitive to small numerical differences, although this is less common than FP16 underflow/overflow.
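For a concrete view of these trade-offs, a quick look at torch.finfo (a minimal sketch; the printed values are properties of the formats themselves) shows why FP16 hits overflow and underflow long before BF16 or FP32:

import torch

# Compare the representable ranges of the three formats
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.3e}  tiny={info.tiny:.3e}  eps={info.eps:.3e}")

# FP16 overflow and underflow in action:
print(torch.tensor(70000.0, dtype=torch.float16))   # inf    (FP16 max is ~65504)
print(torch.tensor(1e-8, dtype=torch.float16))      # 0.     (underflows to zero)
print(torch.tensor(70000.0, dtype=torch.bfloat16))  # 70144. (in range, coarser precision)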
Numerical precision problems often surface abruptly. A training run proceeding smoothly might suddenly encounter NaN loss or gradients. Important indicators include:

- Loss values: if the loss becomes NaN or Inf, backpropagation cannot proceed correctly.
- Gradients: parameter gradients containing NaN or Inf. This prevents optimizer updates for those parameters and can quickly corrupt the model if not handled.
- Activations: intermediate activations containing NaN/Inf values, indicating overflow.

Once you suspect a numerical precision issue, the goal is to isolate the specific operation or module where the instability originates. A NaN or Inf generated in one layer can quickly propagate through subsequent computations.
The simplest first step is often to disable mixed-precision training and run entirely in FP32. If the instability disappears, it strongly implicates the lower-precision format. If the instability persists in FP32, the root cause might be a different issue, like a bug in the model code, bad data, or unsuitable hyperparameters (e.g., an excessively high learning rate).
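As a minimal sketch (assuming a typical single-GPU loop where model, criterion, optimizer, and loader are already defined elsewhere), gating mixed precision behind one flag makes the FP32 comparison a one-line change:

import torch

use_amp = True  # set False to re-run the same loop entirely in FP32

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for input_ids, targets in loader:
    optimizer.zero_grad(set_to_none=True)
    # autocast is a no-op when enabled=False, so the loop runs in FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        output = model(input_ids)
        loss = criterion(output, targets)
    scaler.scale(loss).backward()  # scaling is also a no-op when disabled
    scaler.step(optimizer)
    scaler.update()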
PyTorch's hook mechanism provides a powerful way to inspect intermediate tensors (activations) during the forward pass and gradients during the backward pass without fundamentally altering your model's structure. You can register hooks on specific modules or tensors to check for NaN/Inf values immediately after they are computed.

Here's an example of registering a forward hook to check activations for NaN or Inf values after a specific linear layer:
import torch
import torch.nn as nn

def check_nan_inf_hook(module, input, output):
    """Forward hook to check module output for NaNs/Infs."""
    if isinstance(output, torch.Tensor):
        if torch.isnan(output).any() or torch.isinf(output).any():
            print(f"NaN/Inf detected in output of module: {module}")
            # Optionally, raise an error or enter debugger
            # import pdb; pdb.set_trace()
    elif isinstance(output, tuple):  # Handle modules returning tuples
        for i, out in enumerate(output):
            if isinstance(out, torch.Tensor):
                if torch.isnan(out).any() or torch.isinf(out).any():
                    print(
                        f"NaN/Inf detected in output tuple element {i} "
                        f"of module: {module}"
                    )
                    # import pdb; pdb.set_trace()

# Assuming 'model' is your LLM instance
# Register the hook on a specific layer,
# e.g., the first FFN layer in block 5
target_layer = model.transformer.h[5].mlp.c_fc
handle = target_layer.register_forward_hook(check_nan_inf_hook)

# --- Run your training iteration ---
# output = model(input_ids)
# loss = criterion(output, targets)
# loss.backward()
# ---

# Remember to remove the hook when done debugging
handle.remove()
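If you do not yet know which layer misbehaves, a common variant of this sketch is to attach the same hook to every submodule and let the first message printed point to the earliest culprit:

# Sweep the whole model: the first message printed identifies the
# earliest module whose output contains NaN/Inf
handles = [m.register_forward_hook(check_nan_inf_hook) for m in model.modules()]
# ... run one training iteration ...
for h in handles:
    h.remove()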
Similarly, you can register backward hooks (register_full_backward_hook for modules or register_hook for specific tensors) to inspect gradients (grad_input, grad_output). By strategically placing these hooks, you can narrow down the computation step where instability first appears.
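As a small illustration (reusing target_layer from the forward-hook example, not a prescribed pattern), a module-level backward hook that flags bad gradients might look like this:

def check_grad_hook(module, grad_input, grad_output):
    """Backward hook: flag NaN/Inf gradients flowing out of this module."""
    for i, g in enumerate(grad_output):
        if g is not None and (torch.isnan(g).any() or torch.isinf(g).any()):
            print(f"NaN/Inf detected in grad_output[{i}] of module: {module}")

bwd_handle = target_layer.register_full_backward_hook(check_grad_hook)
# ... forward pass, loss computation, loss.backward() ...
bwd_handle.remove()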
After loss.backward(), you can iterate through model parameters and check their .grad attribute:
import torch

def check_gradients(model):
    """Check all model parameters for NaN/Inf gradients."""
    nan_inf_found = False
    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"NaN detected in gradient of parameter: {name}")
                nan_inf_found = True
            if torch.isinf(param.grad).any():
                print(f"Inf detected in gradient of parameter: {name}")
                nan_inf_found = True
    if not nan_inf_found:
        print("No NaN/Inf gradients detected.")
    return nan_inf_found

# After loss.backward() and before optimizer.step()
# check_gradients(model)

# Optionally, clip gradients before checking
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# check_gradients(model)
If using FP16 with dynamic loss scaling (e.g., via PyTorch's torch.cuda.amp.GradScaler), the scaler itself might provide clues. An overflowing gradient during the backward pass will cause the scaler to skip the optimizer step and decrease the loss scale for subsequent iterations. If this happens repeatedly, the loss scale might become extremely small or even zero, leading to gradient underflow. Conversely, if the loss scale grows very large without issue, it might increase the likelihood of intermediate overflows later.
You can inspect the GradScaler's current scale factor:
import torch

# Assuming 'scaler' is your torch.cuda.amp.GradScaler instance
current_loss_scale = scaler.get_scale()
print(f"Current loss scale: {current_loss_scale}")

# Check whether the scaler skipped the last optimizer step.
# The loss scale only decreases when Inf/NaN gradients are detected, so
# comparing the scale before and after scaler.update() is a reliable signal:
# scale_before = scaler.get_scale()
# scaler.step(optimizer)
# scaler.update()
# if scaler.get_scale() < scale_before:
#     print("Optimizer step skipped due to Inf/NaN gradients.")
Monitoring the current_loss_scale over time can reveal problematic patterns. A scale that repeatedly collapses suggests persistent overflow issues, while a scale that drops to very low values might indicate subsequent underflow problems.
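Putting these pieces together, one possible shape for a single AMP training step, sketched under the assumption that scaler, model, optimizer, criterion, the check_gradients helper from earlier, and one batch (input_ids, targets) are in scope, is to unscale gradients before inspecting or clipping them and to watch the loss scale around scaler.update():

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = criterion(model(input_ids), targets)

scaler.scale(loss).backward()

# Unscale first so gradient checks and clipping operate on true values
scaler.unscale_(optimizer)
check_gradients(model)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scale_before = scaler.get_scale()
scaler.step(optimizer)   # internally skipped if Inf/NaN gradients were found
scaler.update()
if scaler.get_scale() < scale_before:
    print(f"Optimizer step skipped; loss scale reduced to {scaler.get_scale()}")

optimizer.zero_grad(set_to_none=True)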
Certain mathematical operations are inherently more prone to producing NaN or Inf, especially with limited precision:

- Logarithms: torch.log(x) where x ≤ 0. This often occurs with probabilities or softmax outputs that might numerically evaluate to zero. Using torch.log_softmax is generally more stable than torch.log(torch.softmax(x)).
- Square roots: torch.sqrt(x) where x < 0.
- Exponentials: torch.exp(x) for large x can easily overflow.

When debugging, pay close attention to computations involving these operations, especially within custom layers or loss functions. Adding small epsilon values (1e-8 to 1e-6) inside square roots or denominators can sometimes prevent NaNs caused by near-zero intermediate values, but be mindful that this slightly changes the computation.
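To make the log-of-softmax point concrete, here is a small FP32 illustration; the logit values are chosen purely to force underflow, which happens at far smaller magnitudes in FP16:

import torch

logits = torch.tensor([0.0, -200.0])  # exp(-200) underflows to 0.0 in FP32

naive = torch.log(torch.softmax(logits, dim=-1))
print(naive)   # second element is -inf, because log(0) is taken

stable = torch.log_softmax(logits, dim=-1)
print(stable)  # approximately tensor([0., -200.]): stays finite by working in log space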
Debugging often involves applying or tuning the stabilization techniques discussed previously:

- Gradient clipping: clipping keeps gradient magnitudes from growing toward Inf. If instability occurs, consider whether the clipping threshold is appropriate.
- Numerically stable operations: prefer stable formulations (such as log_softmax). Add epsilons carefully where needed.
- Initialization: while less likely to cause NaNs during stable training, poor initialization (Chapter 12) can lead to large activations early on, potentially increasing overflow risk.

Debugging numerical precision issues in large-scale training requires patience and systematic investigation. By monitoring important metrics, leveraging tools like hooks, and understanding the limitations of lower-precision formats, you can effectively diagnose and resolve these instabilities, keeping your lengthy training runs on track.