损失飙升是大型模型训练中最令人沮丧的情况之一。一个顺利运行了数天或数周的训练,可能会突然出现损失值急剧(通常是垂直)增加的情况,有时甚至崩溃为 NaN(非数值)或 inf(无穷大)。这通常会中止有效的训练。要找出根本原因,需要进行系统性排查,因为有多种因素可能引发此类事件。当面对损失飙升时,首要目标是了解它 何时 以及 为何 发生。监控工具(如TensorBoard、Weights & Biases)在此处不可或缺,它们能帮助你精准定位到发生飙升的具体训练步骤。{"layout": {"title": "训练期间损失飙升示例", "xaxis": {"title": "训练步数"}, "yaxis": {"title": "损失", "type": "log"}}, "data": [{"x": [0, 1000, 2000, 3000, 3001, 4000, 5000], "y": [10.5, 8.2, 6.5, 5.1, 55.0, 4.8, 4.5], "mode": "lines+markers", "name": "训练损失", "line": {"color": "#f03e3e"}}]}上图展示了一个典型的损失飙升,损失值突然增加,随后可能恢复或完全偏离。请注意,损失可视化常使用对数刻度。以下是常见原因及其排查方式:1. 检查有问题的数据批次通常,触发因素是单个“异常”批次数据。这可能包括:损坏数据: 未正确处理的文件,含有无意义字符、过长序列或格式错误。数值异常值: 当数据点通过模型初始层(如嵌入层)处理时,产生异常大的激活值,可能随后导致溢出。异常输入格式: 数据不符合预期的分词或填充方案,可能是由于预处理中的某个边界情况。诊断步骤:隔离批次: 如果你的训练框架允许,记录紧接在损失飙升之前批次中包含的样本的索引或标识符。手动检查: 加载并检查识别出的特定数据样本。查找异常长的序列、重复模式、原始数据中的 NaN 值(如果适用),或分词器词汇表外的可能处理不当的字符。单批次复现: 尝试仅使用怀疑有问题批次进行单次前向和反向传播。这可以确认该批次本身是否单独触发了问题。# 示例:在PyTorch中检查批次张量是否含有NaN或Inf # 假设 `input_ids` 是有问题批次的张量 import torch def check_tensor_health(tensor: torch.Tensor, name: str): has_nan = torch.isnan(tensor).any() has_inf = torch.isinf(tensor).any() if has_nan or has_inf: print(f"Problem detected in tensor '{name}':") if has_nan: print(f" - Contains NaN values!") print(f" - NaN count: {torch.isnan(tensor).sum().item()}") if has_inf: print(f" - Contains Inf values!") print(f" - Inf count: {torch.isinf(tensor).sum().item()}") # 可以考虑记录或打印张量中出现问题的部分 # print(tensor[torch.isnan(tensor) | torch.isinf(tensor)]) return False return True # --- 在你的训练循环或调试脚本中 --- # 加载或识别有问题的批次数据(例如,input_ids, attention_mask) # input_ids = load_problematic_batch(...) # 检查输入张量 # if not check_tensor_health(input_ids, "input_ids"): # # 处理错误或设置断点 # pass # 前向传播后,检查模型输出 # model_output = model(input_ids) # loss = calculate_loss(model_output, labels) # if not check_tensor_health(loss, "Calculated Loss"): # 调查损失为何变为NaN/Inf # pass2. 分析梯度和激活值损失飙升在机制上常由梯度爆炸引起。即使输入数据看起来正常,模型内部的操作也可能导致过大的数值。诊断步骤:检查 NaN/inf: 在优化器步骤 (optimizer.step()) 之前,检查损失张量本身以及模型参数的梯度中是否含有 NaN 或 inf 值。 NaN 损失是上游数值不稳定的明确信号。 NaN 梯度会在优化器步骤时损坏权重。监控梯度范数: 如“监控训练指标”章节所述,追踪总梯度范数(例如,所有梯度拼接后的L2范数)。梯度范数的突然大幅增长常常在损失飙升之前发生或同时发生。这表明参数更新即将变得过大。 torch.nn.utils.clip_grad_norm_ 等工具不仅进行裁剪,还会返回裁剪 前 的范数,这对于记录有帮助。激活值统计: 高级调试可能涉及在模型中设置钩子,以记录不同层中激活值的统计数据(最小值、最大值、均值、标准差)。激活值中的极端数值可能是梯度问题的先兆。# 示例:在PyTorch中优化器步骤前检查梯度 # --- 在你的训练循环中,在 loss.backward() 之后 --- total_norm = 0.0 nan_or_inf_found = False for p in model.parameters(): if p.grad is not None: if not check_tensor_health( p.grad, f"Gradient of {p.name if hasattr(p, 'name') else 'parameter'}" ): nan_or_inf_found = True # 可选:中断或记录关于特定参数的更多详情 # break param_norm = p.grad.detach().data.norm(2) total_norm += param_norm.item() ** 2 total_norm = total_norm ** 0.5 print(f"Step {global_step}: Total Gradient Norm: {total_norm}") if nan_or_inf_found: print( f"Step {global_step}: NaN or Inf detected in gradients " f"BEFORE optimizer step. Skipping update." ) # 对于这个批次,可能跳过 optimizer.step() 或停止训练 # optimizer.zero_grad() # 仍然需要清零梯度 # continue or raise Exception elif total_norm > gradient_clipping_threshold * 10: # 任意大的乘数 print( f"Warning: Step {global_step}: Gradient norm ({total_norm}) " f"significantly exceeds clipping threshold " f"({gradient_clipping_threshold}). Potential instability." ) # 可选:梯度裁剪(即使没有飙升也常进行,但在此处很重要) # torch.nn.utils.clip_grad_norm_( # model.parameters(), max_norm=gradient_clipping_threshold # ) # 如果未找到 NaN/Inf 梯度: # optimizer.step() # optimizer.zero_grad()3. 检查学习率和优化器状态尽管在稳定训练后不太可能导致 单次 突然飙升(更常导致数个步骤内的偏离),但学习率可能会起到作用。诊断步骤:检查当前学习率: 记录每一步的学习率值。调度器是否表现异常?是否存在导致学习率重置或跳变的错误?优化器状态: 极其罕见地,自适应优化器(如Adam的动量或方差估计)的内部状态可能会损坏(NaN/inf)。框架通常会处理这种情况,但在非常罕见的情况下,检查 optimizer.state_dict() 可能会发现问题。4. 调查混合精度问题使用 FP16 (16位浮点) 训练特别容易出现数值范围问题。尽管 BF16 (bfloat16) 提供更宽的范围,极端值仍可能导致问题。诊断步骤:损失缩放: 如果使用 FP16 和自动混合精度 (AMP),请确保损失缩放处于启用状态。如果梯度在被损失缩放器反向缩放 之前 变得过大(> 65504,FP16的最大值),则可能发生损失飙升。检查损失缩放值本身是否变为 NaN 或零,这可能发生在梯度在溢出前反复下溢的情况下。在 FP32 中进行的操作: 某些操作可能对数值敏感,显式转换为 FP32 会有益。检查是否有任何自定义操作或数值不稳定的函数(如对极端值进行某些归约或归一化)以较低精度执行。尝试 BF16: 如果硬件支持,BF16 通常比 FP16 更稳定,因为它具有更宽的动态范围,通常无需损失缩放。在使用FP16时遇到飙升可能促使切换到BF16。诊断损失飙升是一个迭代过程。通过系统性地检查数据、监控梯度和激活值、审视优化器配置以及考虑混合精度训练的细节,你通常可以找到不稳定性的来源,并应用本章其他地方讨论的相应缓解技术,例如调整学习率、改进数据清洗或优化梯度裁剪和损失缩放策略。