Effective monitoring is the first line of defense against the instabilities that can derail large-scale language model training. Simply launching a multi-day or multi-week training job and hoping for the best is impractical. Instead, engineers need to continuously observe the model's behavior through important metrics. This allows for early detection of problems, facilitating timely intervention before significant computational resources are wasted. Two of the most informative metrics for diagnosing training health are the training loss and the norm of the gradients.
The training loss quantifies how well the model is performing on the training data at any given moment. For language models, this is typically a measure like cross-entropy loss, which reflects the model's uncertainty in predicting the next token.
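As a quick illustration of how that uncertainty turns into a number, the sketch below uses PyTorch's built-in nn.CrossEntropyLoss on a toy vocabulary of four tokens (the logits and target are made up for this example):

import math
import torch
import torch.nn as nn

# Toy example: a vocabulary of 4 tokens and one next-token prediction.
logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # unnormalized model scores
target = torch.tensor([0])                       # index of the "true" next token

loss = nn.CrossEntropyLoss()(logits, target)
# Cross-entropy is -log(probability assigned to the correct token), so a
# confident correct model approaches 0, while a uniform guess over a
# vocabulary of size V gives log(V) (here log(4) ~ 1.386).
print(f"loss = {loss.item():.4f}, uniform baseline = {math.log(4):.4f}")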
Expected Behavior: In a healthy training run, the loss should generally decrease over time, often rapidly in the early stages and then more slowly as training progresses, eventually plateauing as the model converges. Minor fluctuations are normal, but the overall trend should be downwards.
Signs of Instability: Watch for sudden loss spikes, for divergence where the loss trends upward instead of downward, and especially for the loss becoming NaN (Not a Number) or infinity. A NaN loss is a critical failure, usually caused by numerical overflow (e.g., dividing by zero, or taking the logarithm of zero or a negative number) or by exploding gradients, and training cannot proceed once it occurs.
Implementation: Logging the loss is straightforward within a standard training loop. You typically calculate the loss after the forward pass and log its value periodically.
import torch
import torch.nn as nn

# Assume model, data_loader, optimizer are defined
# Example logging setup (can use TensorBoard, WandB, etc.)
def log_metric(step, metric_name, value):
    # Replace with your actual logging implementation
    print(f"Step {step}: {metric_name} = {value:.4f}")

global_step = 0
for batch in data_loader:
    optimizer.zero_grad()

    # Assume inputs and targets are obtained from the batch
    inputs = batch['input_ids'].to('cuda')
    targets = batch['labels'].to('cuda')

    outputs = model(inputs)

    # Flatten logits and targets for cross-entropy over the vocabulary
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(outputs.view(-1, model.config.vocab_size), targets.view(-1))

    # Log the loss (e.g., every 10 steps)
    if global_step % 10 == 0:
        log_metric(global_step, "train_loss", loss.item())

    # Check for a NaN loss early, before backpropagating or updating weights
    if torch.isnan(loss):
        print(f"NaN loss detected at step {global_step}! Stopping training.")
        break

    loss.backward()
    optimizer.step()
    global_step += 1
Visualizing the loss curve over time is essential. Tools like TensorBoard or Weights & Biases make this easy.
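For instance, the log_metric helper from the loop above could be backed by TensorBoard's SummaryWriter. This is a minimal sketch, assuming TensorBoard is installed and using an arbitrary ./runs/llm-monitoring log directory:

from torch.utils.tensorboard import SummaryWriter

# Writes event files under ./runs; view them with `tensorboard --logdir runs`
writer = SummaryWriter(log_dir="./runs/llm-monitoring")

def log_metric(step, metric_name, value):
    # Scalars logged under the same name are plotted as a curve over steps
    writer.add_scalar(metric_name, value, global_step=step)

# ... training loop as above ...
# writer.close()  # flush remaining events when training ends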
Example loss curves showing stable training, a sudden loss spike, and divergence where the loss increases over time. Note the logarithmic y-axis, common for viewing loss.
The gradient is the vector indicating the direction and magnitude of the steepest ascent of the loss function. The norm of this vector, typically the L2 norm ($\|\nabla L\|_2$), measures its magnitude. Monitoring the gradient norm provides insights into the scale of the updates being applied to the model's weights.
$$\|\nabla L\|_2 = \sqrt{\sum_{p \in \text{parameters}} \|\nabla_p L\|_2^2}$$
where $\nabla_p L$ is the gradient of the loss $L$ with respect to a specific parameter tensor $p$.
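To make the formula concrete, the standalone sketch below (with two made-up "gradient" tensors) checks that combining per-tensor norms this way matches taking the L2 norm of all gradient entries flattened into one vector:

import torch

# Two toy tensors standing in for per-parameter gradients
g1 = torch.tensor([3.0, 4.0])        # ||g1||_2 = 5
g2 = torch.tensor([[1.0], [2.0]])    # ||g2||_2 = sqrt(5)

# Combine per-tensor L2 norms as in the formula above
total = torch.sqrt(torch.norm(g1, 2) ** 2 + torch.norm(g2, 2) ** 2)

# Equivalent: L2 norm of all gradient entries concatenated into one vector
flat = torch.norm(torch.cat([g1.flatten(), g2.flatten()]), 2)
print(total.item(), flat.item())  # both ~ sqrt(25 + 5) ~ 5.477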
Why it's Important: The gradient norm directly reflects how large the weight updates will be at each step, so tracking it can reveal problems such as exploding gradients before they surface in the loss curve.
Expected Behavior: The gradient norm often starts relatively high and decreases as the model converges and the loss flattens near a minimum. However, its behavior is highly dependent on the learning rate schedule, optimizer, and data. Significant fluctuations can occur, but extremely large values are a red flag.
Signs of Instability: A gradient norm that grows rapidly or spikes to extremely large values signals exploding gradients, which frequently precede loss spikes or NaN values.
Implementation: Calculating the total gradient norm requires iterating over all model parameters after the loss.backward() call but before the optimizer.step() call. Gradient clipping, a technique discussed later, often involves calculating this norm anyway.
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

# Assume model, data_loader, optimizer are defined, log_metric is defined
global_step = 0
# Define a max grad norm value for clipping (discussed in stabilization techniques)
max_grad_norm = 1.0

for batch in data_loader:
    optimizer.zero_grad()

    # --- Forward pass and loss calculation as before ---
    inputs = batch['input_ids'].to('cuda')
    targets = batch['labels'].to('cuda')
    outputs = model(inputs)
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        outputs.view(-1, model.config.vocab_size),
        targets.view(-1)
    )

    # Log loss
    if global_step % 10 == 0:
        log_metric(global_step, "train_loss", loss.item())

    if torch.isnan(loss):
        print(f"NaN loss detected at step {global_step}! Stopping training.")
        break

    loss.backward()

    # --- Calculate and log gradient norm BEFORE optimizer step ---
    # Collect the gradients of all parameters that received one
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if len(grads) > 0:
        # Calculate the L2 norm over all gradients
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), 2.0) for g in grads]),
            2.0
        )

        # Log the gradient norm (e.g., every 10 steps)
        if global_step % 10 == 0:
            log_metric(global_step, "grad_norm", total_norm.item())

        # Optional: Clip gradients (common practice)
        # clip_grad_norm_(model.parameters(), max_grad_norm)

        # Check for exploding gradients
        if total_norm > 100 * max_grad_norm:  # Heuristic threshold
            print(
                f"Warning: High gradient norm ({total_norm:.2f}) "
                f"detected at step {global_step}"
            )
    else:
        # Handle case where there are no gradients
        # (e.g., if model has no trainable params)
        if global_step % 10 == 0:
            log_metric(global_step, "grad_norm", 0.0)

    optimizer.step()
    global_step += 1
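If you clip gradients anyway, you can skip the manual bookkeeping: clip_grad_norm_ returns the total norm it measured before clipping, so logging and clipping can share one call. A condensed sketch of that variant, assuming the same model, optimizer, max_grad_norm, and log_metric helper as above:

from torch.nn.utils import clip_grad_norm_

loss.backward()
# clip_grad_norm_ rescales gradients in place if their combined L2 norm
# exceeds max_grad_norm, and returns the norm measured before clipping
total_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
if global_step % 10 == 0:
    log_metric(global_step, "grad_norm", total_norm.item())
optimizer.step()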
Visualizing the gradient norm alongside the loss provides a more complete picture of training dynamics.
Example gradient L2 norm curves showing a stable decrease and an exploding gradient scenario where the norm increases dramatically.
By carefully monitoring both the training loss and the gradient norm, you gain essential visibility into the training process. These metrics act as early warning systems, allowing you to identify potential instabilities before they escalate into critical failures, saving valuable time and computational resources. Anomalies in these curves are the starting point for diagnosing the specific underlying issues, which we will explore next.