趋近智
全参数 (parameter)微调 (fine-tuning)通常需要较长时间的训练,可能数小时、数天甚至数周,耗费大量计算资源。在这些长时间运行中,中断几乎无法避免,无论是硬件故障、集群作业抢占、网络问题,还是仅仅需要暂停和重启。如果没有保存和恢复进度的机制,这类中断可能导致之前所有训练工作完全丢失,浪费宝贵的时间和计算预算。检查点保存提供这种重要的安全保障。
有效的检查点保存需要定期将训练过程的完整状态保存到持久化存储中。这使您能够从中断处准确恢复训练,最大程度减少浪费。
为了恢复训练,您不仅仅需要保存模型的参数 (parameter)()。一个全面的检查点应包含:
model.state_dict() 访问。保存它能确保您保留已学习的信息。optimizer.state_dict() 对于保持这种动量是必要的。scheduler.state_dict() 保存。否则会重置调度,导致恢复时学习率不正确。torch.get_rng_state()、numpy.random.get_state()、random.getstate())可能会有帮助。恢复这些状态可确保数据管道在恢复后行为一致。检查点保存逻辑通常直接集成到训练循环中。您需要决定检查点保存的频率。常见做法包括:
为了管理存储,您可以只保留最新检查点、最佳检查点,或者保留最近几个检查点的滚动窗口。
这是一个类似PyTorch的检查点保存示例:
# 假设模型、优化器、调度器、迭代次数、全局步数已定义
checkpoint_path = f"./model_checkpoint_epoch_{epoch}_step_{global_step}.pt"
save_state = {
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model.state_dict(), # 如果使用DDP,也可以是model.module.state_dict()
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
# 可选地添加RNG状态、配置、损失历史等
# 'rng_state': torch.get_rng_state(),
# 'config': training_args,
}
# 最佳实践:保存到临时文件,然后重命名以确保原子性
temp_path = checkpoint_path + ".tmp"
torch.save(save_state, temp_path)
os.rename(temp_path, checkpoint_path) # 原子性重命名
print(f"检查点已保存到 {checkpoint_path}")
# 可选地管理旧检查点(例如,只保留最后3个)
# manage_checkpoints(checkpoint_dir, keep_last=3)
使用临时文件并重命名可确保在保存过程中断时,不会得到损坏的部分检查点文件。
开始训练运行时,您的脚本应检查是否存在有效的检查点。如果存在,应在开始训练前加载已保存的状态。
以下是您加载状态的方式:
# 假设模型、优化器、调度器已初始化
# 确定要加载的检查点路径(例如,最新的一个)
checkpoint_path = find_latest_checkpoint("./") # 查找检查点文件的函数
if checkpoint_path:
print(f"从检查点恢复训练: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu') # 首先加载到CPU
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch']
global_step = checkpoint['global_step']
# 可选地恢复RNG状态等
# torch.set_rng_state(checkpoint['rng_state'])
# 重要提示:如果使用GPU,将优化器状态移至正确的设备
# 根据优化器状态的保存方式以及是否最初将检查点加载到CPU,可能需要此步骤。
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.to(device) # 设备可以是'cuda:0'
print(f"从迭代 {start_epoch}、全局步 {global_step} 恢复")
else:
print("未找到检查点,从头开始训练。")
start_epoch = 0
global_step = 0
# --- 从start_epoch开始训练循环,跟踪global_step ---
加载优化器和调度器状态对于维持训练过程的稳定性很重要。如果检查点很大,在加载到模型/优化器之前,首先将检查点加载到CPU(map_location='cpu')可以防止GPU内存问题。如果使用GPU,请确保在加载后将优化器状态张量移到正确的设备。
当使用PyTorch的分布式数据并行(DDP)等分布式训练框架时,检查点保存需要进行微小调整:
model.module.state_dict())。torch.distributed.barrier())来确保所有进程在继续训练之前都已加载状态。训练循环,其在开始时加载检查点并在过程中定期保存。中断会触发重启,然后加载最新检查点。
掌握检查点保存和恢复不仅是便利,也是必要,为了可靠地执行大型语言模型全参数 (parameter)微调 (fine-tuning)这一资源密集型过程。它确保进度在中断情况下得以保留,使长时间训练运行变得可行和易于管理。
简洁的语法。内置调试功能。从第一天起就可投入生产。
为 ApX 背后的 AI 系统而构建
这部分内容有帮助吗?
Trainer 类的文档解释了其在微调大型语言模型时强大的检查点功能,包括保存、加载和管理检查点的策略。© 2026 ApX Machine LearningAI伦理与透明度•