趋近智
损失飙升是大型模型训练中最令人沮丧的情况之一。一个顺利运行了数天或数周的训练,可能会突然出现损失值急剧(通常是垂直)增加的情况,有时甚至崩溃为 NaN(非数值)或 inf(无穷大)。这通常会中止有效的训练。要找出根本原因,需要进行系统性排查,因为有多种因素可能引发此类事件。
当面对损失飙升时,首要目标是了解它 何时 以及 为何 发生。监控工具(如TensorBoard、Weights & Biases)在此处不可或缺,它们能帮助你精准定位到发生飙升的具体训练步骤。
上图展示了一个典型的损失飙升,损失值突然增加,随后可能恢复或完全偏离。请注意,损失可视化常使用对数刻度。
以下是常见原因及其排查方式:
通常,触发因素是单个“异常”批次数据。这可能包括:
诊断步骤:
NaN 值(如果适用),或分词器 (tokenizer)词汇表 (vocabulary)外的可能处理不当的字符。# 示例:在PyTorch中检查批次张量是否含有NaN或Inf
# 假设 `input_ids` 是有问题批次的张量
import torch
def check_tensor_health(tensor: torch.Tensor, name: str):
has_nan = torch.isnan(tensor).any()
has_inf = torch.isinf(tensor).any()
if has_nan or has_inf:
print(f"Problem detected in tensor '{name}':")
if has_nan:
print(f" - Contains NaN values!")
print(f" - NaN count: {torch.isnan(tensor).sum().item()}")
if has_inf:
print(f" - Contains Inf values!")
print(f" - Inf count: {torch.isinf(tensor).sum().item()}")
# 可以考虑记录或打印张量中出现问题的部分
# print(tensor[torch.isnan(tensor) | torch.isinf(tensor)])
return False
return True
# --- 在你的训练循环或调试脚本中 ---
# 加载或识别有问题的批次数据(例如,input_ids, attention_mask)
# input_ids = load_problematic_batch(...)
# 检查输入张量
# if not check_tensor_health(input_ids, "input_ids"):
# # 处理错误或设置断点
# pass
# 前向传播后,检查模型输出
# model_output = model(input_ids)
# loss = calculate_loss(model_output, labels)
# if not check_tensor_health(loss, "Calculated Loss"):
# 调查损失为何变为NaN/Inf
# pass
损失飙升在机制上常由梯度爆炸引起。即使输入数据看起来正常,模型内部的操作也可能导致过大的数值。
诊断步骤:
NaN/inf: 在优化器步骤 (optimizer.step()) 之前,检查损失张量本身以及模型参数 (parameter)的梯度中是否含有 NaN 或 inf 值。 NaN 损失是上游数值不稳定的明确信号。 NaN 梯度会在优化器步骤时损坏权重 (weight)。torch.nn.utils.clip_grad_norm_ 等工具不仅进行裁剪,还会返回裁剪 前 的范数,这对于记录有帮助。# 示例:在PyTorch中优化器步骤前检查梯度
# --- 在你的训练循环中,在 loss.backward() 之后 ---
total_norm = 0.0
nan_or_inf_found = False
for p in model.parameters():
if p.grad is not None:
if not check_tensor_health(
p.grad,
f"Gradient of {p.name if hasattr(p, 'name') else 'parameter'}"
):
nan_or_inf_found = True
# 可选:中断或记录关于特定参数的更多详情
# break
param_norm = p.grad.detach().data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Step {global_step}: Total Gradient Norm: {total_norm}")
if nan_or_inf_found:
print(
f"Step {global_step}: NaN or Inf detected in gradients "
f"BEFORE optimizer step. Skipping update."
)
# 对于这个批次,可能跳过 optimizer.step() 或停止训练
# optimizer.zero_grad() # 仍然需要清零梯度
# continue or raise Exception
elif total_norm > gradient_clipping_threshold * 10:
# 任意大的乘数
print(
f"Warning: Step {global_step}: Gradient norm ({total_norm}) "
f"significantly exceeds clipping threshold "
f"({gradient_clipping_threshold}). Potential instability."
)
# 可选:梯度裁剪(即使没有飙升也常进行,但在此处很重要)
# torch.nn.utils.clip_grad_norm_(
# model.parameters(), max_norm=gradient_clipping_threshold
# )
# 如果未找到 NaN/Inf 梯度:
# optimizer.step()
# optimizer.zero_grad()
尽管在稳定训练后不太可能导致 单次 突然飙升(更常导致数个步骤内的偏离),但学习率可能会起到作用。
诊断步骤:
NaN/inf)。框架通常会处理这种情况,但在非常罕见的情况下,检查 optimizer.state_dict() 可能会发现问题。使用 FP16 (16位浮点) 训练特别容易出现数值范围问题。尽管 BF16 (bfloat16) 提供更宽的范围,极端值仍可能导致问题。
诊断步骤:
FP16 和自动混合精度 (AMP),请确保损失缩放处于启用状态。如果梯度在被损失缩放器反向缩放 之前 变得过大(> 65504,FP16的最大值),则可能发生损失飙升。检查损失缩放值本身是否变为 NaN 或零,这可能发生在梯度在溢出前反复下溢的情况下。FP32 中进行的操作: 某些操作可能对数值敏感,显式转换为 FP32 会有益。检查是否有任何自定义操作或数值不稳定的函数(如对极端值进行某些归约或归一化 (normalization))以较低精度执行。BF16: 如果硬件支持,BF16 通常比 FP16 更稳定,因为它具有更宽的动态范围,通常无需损失缩放。在使用FP16时遇到飙升可能促使切换到BF16。诊断损失飙升是一个迭代过程。通过系统性地检查数据、监控梯度和激活值、审视优化器配置以及考虑混合精度训练的细节,你通常可以找到不稳定性的来源,并应用本章其他地方讨论的相应缓解技术,例如调整学习率、改进数据清洗或优化梯度裁剪和损失缩放策略。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造