全参数微调通常需要较长时间的训练,可能数小时、数天甚至数周,耗费大量计算资源。在这些长时间运行中,中断几乎无法避免,无论是硬件故障、集群作业抢占、网络问题,还是仅仅需要暂停和重启。如果没有保存和恢复进度的机制,这类中断可能导致之前所有训练工作完全丢失,浪费宝贵的时间和计算预算。检查点保存提供这种重要的安全保障。有效的检查点保存需要定期将训练过程的完整状态保存到持久化存储中。这使您能够从中断处准确恢复训练,最大程度减少浪费。训练状态包含什么?为了恢复训练,您不仅仅需要保存模型的参数($\theta$)。一个全面的检查点应包含:模型状态字典: 这包含模型所有可学习的参数(权重和偏差)。在PyTorch等框架中,这通常通过 model.state_dict() 访问。保存它能确保您保留已学习的信息。优化器状态字典: 现代优化器,如Adam或AdamW,维护着内部状态,例如梯度和平方梯度的移动平均值(动量)。简单地恢复模型权重并初始化一个新的优化器会重置这些状态,可能扰乱到目前为止的学习过程和收敛表现。保存 optimizer.state_dict() 对于保持这种动量是必要的。学习率调度器状态: 如果您正在使用学习率调度器(例如,线性预热后衰减),其内部状态(如当前学习率乘数、步数或上次迭代)必须使用 scheduler.state_dict() 保存。否则会重置调度,导致恢复时学习率不正确。训练进度指示: 重要的元数据,例如当前迭代次数、已完成的训练步数或已处理的样本数量。这使训练循环能够正确判断从何处重新开始数据加载和迭代计数。随机数生成器(RNG)状态: 为了严格的可复现性,特别是如果您的数据加载、混洗或增强涉及随机性,保存RNG的状态(例如,torch.get_rng_state()、numpy.random.get_state()、random.getstate())可能会有帮助。恢复这些状态可确保数据管道在恢复后行为一致。(可选) 损失历史和评估指标: 保存训练过程中跟踪的性能指标(例如,训练损失、验证准确性)直到检查点,有助于监控和分析,尽管对于恢复训练并非硬性要求。(可选) 训练配置: 保存运行所用的配置参数或脚本参数,有助于在略有不同的环境或代码更改后恢复时保持一致性。实现检查点保存检查点保存逻辑通常直接集成到训练循环中。您需要决定检查点保存的频率。常见做法包括:每N步保存: 提供细粒度的恢复点,但如果N很小,可能导致更高的I/O开销。每轮迭代保存: 更易于管理,但如果中断发生在迭代后期,可能意味着损失整轮迭代的工作。基于时间间隔保存: 例如,每小时。保存表现最佳的检查点: 基于验证指标,确保始终保留迄今为止表现最佳的模型状态。这通常与保存最新检查点相结合。为了管理存储,您可以只保留最新检查点、最佳检查点,或者保留最近几个检查点的滚动窗口。这是一个类似PyTorch的检查点保存示例:# 假设模型、优化器、调度器、迭代次数、全局步数已定义 checkpoint_path = f"./model_checkpoint_epoch_{epoch}_step_{global_step}.pt" save_state = { 'epoch': epoch, 'global_step': global_step, 'model_state_dict': model.state_dict(), # 如果使用DDP,也可以是model.module.state_dict() 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), # 可选地添加RNG状态、配置、损失历史等 # 'rng_state': torch.get_rng_state(), # 'config': training_args, } # 最佳实践:保存到临时文件,然后重命名以确保原子性 temp_path = checkpoint_path + ".tmp" torch.save(save_state, temp_path) os.rename(temp_path, checkpoint_path) # 原子性重命名 print(f"检查点已保存到 {checkpoint_path}") # 可选地管理旧检查点(例如,只保留最后3个) # manage_checkpoints(checkpoint_dir, keep_last=3)使用临时文件并重命名可确保在保存过程中断时,不会得到损坏的部分检查点文件。从检查点恢复开始训练运行时,您的脚本应检查是否存在有效的检查点。如果存在,应在开始训练前加载已保存的状态。以下是您加载状态的方式:# 假设模型、优化器、调度器已初始化 # 确定要加载的检查点路径(例如,最新的一个) checkpoint_path = find_latest_checkpoint("./") # 查找检查点文件的函数 if checkpoint_path: print(f"从检查点恢复训练: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location='cpu') # 首先加载到CPU model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scheduler.load_state_dict(checkpoint['scheduler_state_dict']) start_epoch = checkpoint['epoch'] global_step = checkpoint['global_step'] # 可选地恢复RNG状态等 # torch.set_rng_state(checkpoint['rng_state']) # 重要提示:如果使用GPU,将优化器状态移至正确的设备 # 根据优化器状态的保存方式以及是否最初将检查点加载到CPU,可能需要此步骤。 for state in optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(device) # 设备可以是'cuda:0' print(f"从迭代 {start_epoch}、全局步 {global_step} 恢复") else: print("未找到检查点,从头开始训练。") start_epoch = 0 global_step = 0 # --- 从start_epoch开始训练循环,跟踪global_step ---加载优化器和调度器状态对于维持训练过程的稳定性很重要。如果检查点很大,在加载到模型/优化器之前,首先将检查点加载到CPU(map_location='cpu')可以防止GPU内存问题。如果使用GPU,请确保在加载后将优化器状态张量移到正确的设备。分布式训练的注意事项当使用PyTorch的分布式数据并行(DDP)等分布式训练框架时,检查点保存需要进行微小调整:保存: 通常,只有主进程(rank 0)应保存检查点,以避免重复写入和潜在的竞争条件。模型状态字典可能需要以不同的方式访问(例如,DDP中为model.module.state_dict())。加载: 所有进程都需要加载模型权重以确保一致性。但是,迭代次数和步数等元数据可能只被主进程需要。使用同步原语(如torch.distributed.barrier())来确保所有进程在继续训练之前都已加载状态。最佳实践频率: 选择一个能平衡恢复能力与I/O开销的频率。对于非常大的模型和长时间训练,每几百或几千步保存一次可能比较合适。原子性: 始终先保存到临时文件,然后重命名,以防止文件损坏。存储位置: 使用可靠的存储。网络文件系统(NFS)或云存储桶(S3、GCS、Azure Blob Storage)通常优于本地磁盘,特别是在节点可能短暂存在的集群环境中。验证: 如果数据损坏是主要顾虑,可以考虑在保存后添加简单检查(例如,确认文件大小或尝试快速重新加载)。代码版本控制: 请注意,在保存和加载检查点之间,模型架构或库版本的显著更改可能导致兼容性问题。如果可能,将相关的版本信息存储在检查点中。digraph G { rankdir=TB; node [shape=box, style="rounded,filled", fontname="Arial", margin=0.1, fillcolor="#e9ecef", color="#adb5bd"]; edge [fontname="Arial", fontsize=10]; Start [label="开始训练运行", fillcolor="#a5d8ff", shape=ellipse]; CheckCP [label="存在检查点?", shape=diamond, fillcolor="#ffe066"]; LoadCP [label="加载检查点\n(模型、优化器、调度器、步数)", fillcolor="#ffc9c9"]; InitState [label="初始化状态\n(迭代次数=0, 步数=0)", fillcolor="#b2f2bb"]; TrainStep [label="执行训练步\n(前向、反向、优化)", fillcolor="#bac8ff"]; CheckSave [label="是否需要保存检查点?\n(步数/迭代/时间)", shape=diamond, fillcolor="#ffe066"]; SaveCP [label="保存检查点\n(模型、优化器、调度器、步数)", fillcolor="#d8f5a2"]; CheckDone [label="训练完成?", shape=diamond, fillcolor="#ffe066"]; End [label="结束训练", fillcolor="#a5d8ff", shape=ellipse]; Failure [label="中断!\n(崩溃/抢占)", shape=cds, fillcolor="#ff8787", style="filled", penwidth=1.5]; Start -> CheckCP; CheckCP -> LoadCP [label=" 是 "]; CheckCP -> InitState [label=" 否 "]; LoadCP -> TrainStep; InitState -> TrainStep; TrainStep -> CheckSave; CheckSave -> SaveCP [label=" 是 "]; CheckSave -> CheckDone [label=" 否 "]; SaveCP -> CheckDone; CheckDone -> TrainStep [label=" 否 "]; CheckDone -> End [label=" 是 "]; // Failure Path TrainStep -> Failure [style=dashed, color="#f03e3e", arrowhead=open, constraint=false, label=" 潜在\n故障 "]; SaveCP -> Failure [style=dashed, color="#f03e3e", arrowhead=open, constraint=false, label=" 潜在\n故障 "]; Failure -> Start [style=dashed, color="#1c7ed6", arrowhead=open, label=" 重启脚本 "]; }训练循环,其在开始时加载检查点并在过程中定期保存。中断会触发重启,然后加载最新检查点。掌握检查点保存和恢复不仅是便利,也是必要,为了可靠地执行大型语言模型全参数微调这一资源密集型过程。它确保进度在中断情况下得以保留,使长时间训练运行变得可行和易于管理。