趋近智
requires_grad)backward()).grad)torch.nn 搭建模型torch.nn.Module 基类torch.nn 损失)torch.optim)torch.utils.data.Datasettorchvision.transforms)torch.utils.data.DataLoader有效的训练依赖于反向传播期间计算的梯度。这些梯度引导优化器更新模型参数以最小化损失函数。然而,这些梯度的大小有时会成为问题,导致训练不稳定或停滞。两个常见问题是梯度消失和梯度爆炸。了解如何检查梯度是诊断训练中出现的问题的一项重要技能。
在反向传播过程中,梯度使用链式法则逐层计算。在深度网络中,这涉及将许多小数字(导数)相乘。
NaN(非数字),从而有效地停止训练。梯度爆炸可能由于不佳的权重初始化、过高的学习率或某些网络结构引起,尤其是在循环神经网络中。PyTorch 的 Autograd 系统计算梯度,并将其存储在 requires_grad=True 的张量的 .grad 属性中。这些梯度在 loss.backward() 调用后即可访问,并在 optimizer.step() 更新模型参数或 optimizer.zero_grad() 清除梯度之前保持可用。
一个常用做法是监控模型中所有可训练参数的梯度整体大小(范数)。L2 范数(欧几里得范数)是常用的一种。非常小的范数表明梯度消失,而非常大或 NaN 的范数则表明梯度爆炸。
以下是在训练循环中计算并记录总梯度范数的方法:
# 在训练循环中,在 loss.backward() 之后
total_norm = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.detach().data.norm(2) # 计算此参数梯度的 L2 范数
total_norm += param_norm.item() ** 2 # 平方和
total_norm = total_norm ** 0.5 # 平方和的平方根
print(f"总梯度范数: {total_norm}")
# 通常,你会使用 TensorBoard 或其他日志框架来记录此值
随时间监控此值可以提供信息:
模型梯度总 L2 范数随训练步数变化的趋势,以对数尺度显示。稳定的训练显示相对一致的范数,梯度爆炸显示快速增加(常导致 NaN),梯度消失则显示趋近于零的下降。
有时,梯度问题可能局限于特定层。你可以直接检查单个参数的梯度。
# 在训练循环中,在 loss.backward() 之后
# 示例:检查第一个卷积层的权重梯度
if hasattr(model, 'conv1') and model.conv1.weight.grad is not None:
conv1_grad_mean = model.conv1.weight.grad.abs().mean().item()
conv1_grad_max = model.conv1.weight.grad.abs().max().item()
print(f"层 conv1 - 平均绝对梯度: {conv1_grad_mean:.6f}, 最大绝对梯度: {conv1_grad_max:.6f}")
# 示例:检查特定线性层的偏置梯度
if hasattr(model, 'fc2') and model.fc2.bias.grad is not None:
fc2_bias_grad_norm = model.fc2.bias.grad.norm(2).item()
print(f"层 fc2 (偏置) - L2 范数: {fc2_bias_grad_norm:.6f}")
查看平均或最大绝对梯度值,或特定层的范数,可以帮助确定梯度是在减小还是在不受控制地增长。使用直方图(例如,使用 Matplotlib 或通过 TensorBoard 记录)来可视化某一层的梯度值分布也很有用。
为了进行更详细的调试,PyTorch 提供了钩子。可以在任何 nn.Module 上注册一个反向钩子(register_full_backward_hook)。当为该模块计算了梯度时,此钩子会执行,允许你检查甚至修改通过它的梯度(grad_input,grad_output)。尽管功能强大,但钩子会增加复杂性,通常在简单检查方法不足时使用。
间接来看,训练损失本身就是一个强有力的指示器。
NaN: 几乎总是梯度爆炸或数学上无效操作(如 log(0))的迹象。检测梯度问题是第一步。解决这些问题通常涉及其他地方更详细介绍的技术,但常见策略包括:
torch.nn.utils.clip_grad_norm_ 或 torch.nn.utils.clip_grad_value_ 是标准实用工具。模型工作后,你不一定需要在每次训练运行时都检查梯度,但当训练不稳定或无效时,它是一个必不可少的诊断工具。通过监控梯度范数和检查单个层的梯度,你可以获得关于训练动态的有价值信息,并发现潜在的梯度消失或梯度爆炸问题。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造