趋近智
大师班
为确保训练在中断后能够恢复,仅保存模型参数是不够的。一个完整的训练状态包含多个组成部分,它们记录了训练中断时的确切位置。未能保存并恢复这些组成部分的任何一项,都可能导致收敛不佳、结果难以复现,或训练过程无法正确继续。让我们看看需要记录的必要状态组成。
模型参数是检查点的一个主要部分。这些参数,通常称为权重和偏置,它们定义了神经网络的学习功能。访问这些参数的标准方式在 PyTorch 中是通过 state_dict()。这个字典将每个层或缓冲区名称映射到其对应的张量。保存模型的 state_dict 能够保证在恢复时,模型从中断前已达到的确切学习表示开始。
# 假设 'model' 是你的 PyTorch nn.Module 实例
model_state = model.state_dict()
# 例子:保存模型状态
# torch.save(model_state, 'model_checkpoint.pt')
# 例子:加载模型状态
# loaded_state = torch.load('model_checkpoint.pt')
# model.load_state_dict(loaded_state)
对于大型模型,state_dict 本身可能非常大,可能达到数百 GB 甚至数 TB。处理这些大文件需要仔细考量存储和 I/O 效率,尤其是在分布式设置中,我们稍后会讨论这些。
现代优化器,特别是像 Adam 或 AdamW 这样常用于训练大型语言模型的自适应优化器,它们不仅维护超参数(如学习率或权重衰减),还维护内部状态。例如,Adam 会为每个参数维护梯度的第一动量(均值)和第二动量(未中心化方差)的估计值。
mtvt=β1mt−1+(1−β1)gt=β2vt−1+(1−β2)gt2在这里,mt 和 vt 代表参数在时间步 t 的移动平均值,它们基于梯度 gt 和衰减因子 β1 与 β2。这些动量估计值与训练路径上的特定点对应。如果只恢复模型权重而重新初始化优化器,这些历史梯度统计信息就会丢失。优化器会实际从头开始,这会明显扰乱训练动态,可能减缓收敛速度,或使模型收敛到一个不同且可能更差的局部最小值。因此,保存优化器状态对于顺利恢复训练是必不可少的。
与模型类似,PyTorch 优化器也提供了 state_dict() 方法。
# 假设 'optimizer' 是你的 PyTorch 优化器实例
# (例如,torch.optim.AdamW)
optimizer_state = optimizer.state_dict()
# 例子:保存优化器状态
# torch.save(optimizer_state, 'optimizer_checkpoint.pt')
# 例子:加载优化器状态
# loaded_state = torch.load('optimizer_checkpoint.pt')
# optimizer.load_state_dict(loaded_state)
优化器状态不仅包括动量估计(对于 Adam 等优化器),还包括内部步数计数以及可能由优化器实例管理的其他超参数。
大型语言模型训练几乎普遍采用学习率调度策略。常见的策略包括学习率逐渐增加的热身阶段,以及随后的衰减阶段(例如,线性、余弦或多项式衰减)。这些调度策略对于训练稳定性和获得良好性能非常重要。
调度器的行为取决于当前的训练进度,通常以步数或 epoch 数衡量。为正确恢复学习率调度,必须保存其内部状态。这可能包括已进行的步数、上次计算的学习率,或特定调度器逻辑使用的其他内部计数器。
# 假设 'scheduler' 是你的 PyTorch 学习率调度器实例
# (例如,torch.optim.lr_scheduler.LambdaLR)
# 确保在训练期间适当地调度步进
# (例如,scheduler.step())
scheduler_state = scheduler.state_dict()
# 例子:保存调度器状态
# torch.save(scheduler_state, 'scheduler_checkpoint.pt')
# 例子:加载调度器状态
# loaded_state = torch.load('scheduler_checkpoint.pt')
# scheduler.load_state_dict(loaded_state)
恢复调度器状态能够保证学习率从中断点继续其预设的进程,而不是不适当地重新开始热身或衰减阶段。
在记录核心模型和优化组件的同时,您还需要追踪训练运行的整体进度。这通常包括:
保存这些计数器能够让你准确知道在训练进程和数据集迭代的何处恢复。
为了严格复现性,特别是在研究环境或调试时,保存所使用的随机数生成器(例如 Python 的 random、NumPy 的 numpy.random 以及 PyTorch 的 torch.cuda.manual_seed_all 状态)的状态很重要。这能够确保数据混洗、dropout 模式以及训练流程中的任何其他随机元素在恢复时保持一致。
import torch
import random
import numpy as np
# 例子:保存 RNG 状态
rng_states = {
'python_rng_state': random.getstate(),
'numpy_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state_all() # 保存所有 GPU 的状态
}
# torch.save(rng_states, 'rng_checkpoint.pt')
# 例子:加载 RNG 状态
# loaded_rng_states = torch.load('rng_checkpoint.pt')
# random.setstate(loaded_rng_states['python_rng_state'])
# np.random.set_state(loaded_rng_states['numpy_rng_state'])
# torch.set_rng_state(loaded_rng_states['torch_rng_state'])
# torch.cuda.set_rng_state_all(loaded_rng_states['cuda_rng_state'])
在实践中,所有这些状态组成部分通常会一起保存到单个字典或结构化文件格式中。这使得管理和加载检查点变得更简单。
# 假设模型、优化器、调度器、全局步数、当前 epoch 均已存在
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'global_step': global_step,
'epoch': current_epoch,
# 如果需要严格复现性,添加 RNG 状态
'rng_states': {
'python_rng_state': random.getstate(),
'numpy_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state_all()
}
# (可选)包含其他元数据:损失、准确率、框架版本等
'loss': current_loss
}
# 保存合并后的检查点
checkpoint_path = f"checkpoint_step_{global_step}.pt"
# torch.save(checkpoint, checkpoint_path)
# --- 之后,在恢复时 ---
# loaded_checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(
# loaded_checkpoint['model_state_dict'])
# optimizer.load_state_dict(
# loaded_checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(
# loaded_checkpoint['scheduler_state_dict'])
# global_step = loaded_checkpoint['global_step']
# current_epoch = loaded_checkpoint['epoch']
# current_loss = loaded_checkpoint.get('loss', None) # 处理可选键
# 如果已保存,恢复 RNG 状态
# rng_states = loaded_checkpoint.get('rng_states')
# if rng_states:
# random.setstate(rng_states['python_rng_state'])
# np.random.set_state(rng_states['numpy_rng_state'])
# torch.set_rng_state(rng_states['torch_rng_state'])
# torch.cuda.set_rng_state_all(rng_states['cuda_rng_state'])
# 现在可以继续训练循环
# 从恢复的状态
通过认真保存这些组成部分,你就为容错奠定了根本。当中断发生时,你可以自信地恢复完整的训练上下文,并以最小的干扰和计算浪费继续该过程。接下来,我们将讨论如何在分布式环境中有效管理这一过程,以及优化检查点频率和存储的策略。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造