训练大型语言模型(LLM)并非一个快速过程。与可能在数小时内完成训练的小型模型不同,预训练一个先进的LLM通常需要连续运行大型加速器(GPU或TPU)集群数天、数周乃至数月。设想一个训练任务,使用1024块高端GPU持续30天。这表示超过730,000加速器小时。这种长时间、大规模的运行显著增加了遇到潜在中断的可能。大规模分布式系统的现实是故障时有发生。在长时间运行中,遇到问题的可能性接近必然。这些中断可能源于多种情况:硬件故障: 单个GPU或TPU可能出现故障。整个计算节点可能因硬件故障(内存错误、电源问题)而崩溃。节点之间的网络连接,对分布式训练通信很要紧,可能变得不可靠或完全失效。软件异常: 训练脚本、深度学习框架(如PyTorch)、底层库(CUDA、NCCL)乃至操作系统中都可能出现程序错误。组件间意外的交互可能导致崩溃或死锁。如果资源使用意外激增,可能发生内存不足(OOM)错误。基础设施问题: 托管计算集群的数据中心可能出现电源波动或停电。集群的计划内维护可能需要停止作业。如果使用基于云的抢占式实例或竞价实例来管理成本,这些实例可能在几乎没有通知的情况下被云服务商收回。如果没有定期保存进度的机制,任何此类中断都会强制整个训练过程从头开始。这会带来严重后果:计算资源浪费: 在故障点之前进行的所有计算都将丢失。在我们1024块GPU的例子中,即使训练进行到一半发生故障,也将白白损失超过365,000 GPU小时,这代表了巨大的财务成本和能源消耗。时间损失: 训练运行通常处于研究项目或产品开发周期的关键路径上。从头开始会引入明显延误,可能以天或周计算,影响项目进度,并可能阻碍竞争优势。挫败感增加: 调试大型分布式作业中的故障本就复杂。不得不反复重新开始会给相关工程师增加巨大的额外负担和挫败感。这就是检查点变得必不可少的地方。检查点是将训练作业的完整状态定期保存到持久存储(如分布式文件系统或云存储)的做法。这种状态不仅包括模型的参数(权重),还包括精确地从上次停止的地方恢复训练所需的一切,例如优化器的状态、学习率调度器的状态、当前的训练迭代或周期数,以及数据加载器的状态。设想一个简化的训练循环:# 未使用检查点的例子 import torch import torch.optim as optim model = MyLargeModel() optimizer = optim.AdamW(model.parameters(), lr=1e-4) # 假设 data_loader 提供数据批次 for step in range(TOTAL_TRAINING_STEPS): # --- 潜在故障点 --- if some_failure_condition(): print("发生故障!从第 0 步重新开始。") # 到 'step' 之前的所有进度都已丢失。 # 需要重新初始化模型、优化器,并从第 0 步开始。 raise SystemExit("训练失败") batch = next(iter(data_loader)) outputs = model(batch['input_ids']) loss = calculate_loss(outputs, batch['labels']) optimizer.zero_grad() loss.backward() optimizer.step() if step % LOG_INTERVAL == 0: print(f"步数: {step}, 损失: {loss.item()}") print("训练成功完成!") # 仅在没有故障发生时才能到达如果在一个百万步的训练运行中,第500,000步发生故障,为这500,000步所做的工作都白费了。检查点机制引入了保存点:# 使用检查点的例子 import torch import torch.optim as optim import os CHECKPOINT_DIR = "/path/to/persistent/storage/checkpoints" CHECKPOINT_FREQ = 1000 # 每 1000 步保存一次 def save_checkpoint(model, optimizer, step, filename="checkpoint.pt"): checkpoint_path = os.path.join(CHECKPOINT_DIR, f"step_{step}_{filename}") state = { 'step': step, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), # 添加调度器状态、随机状态等。 } torch.save(state, checkpoint_path) print(f"在第 {step} 步保存检查点至 {checkpoint_path}") def load_checkpoint(model, optimizer): # 寻找最新检查点的逻辑 latest_checkpoint_path = find_latest_checkpoint(CHECKPOINT_DIR) if latest_checkpoint_path: checkpoint = torch.load(latest_checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_step = checkpoint['step'] + 1 print(f"从第 {start_step} 步的检查点恢复") return start_step else: print("未找到检查点,从头开始。") return 0 model = MyLargeModel() optimizer = optim.AdamW(model.parameters(), lr=1e-4) start_step = load_checkpoint(model, optimizer) # 尝试恢复 for step in range(start_step, TOTAL_TRAINING_STEPS): # --- 潜在故障点 --- try: batch = next(iter(data_loader)) outputs = model(batch['input_ids']) loss = calculate_loss(outputs, batch['labels']) optimizer.zero_grad() loss.backward() optimizer.step() if step % LOG_INTERVAL == 0: print(f"步数: {step}, 损失: {loss.item()}") # --- 定期保存进度 --- if step % CHECKPOINT_FREQ == 0 and step > 0: save_checkpoint(model, optimizer, step) except Exception as e: print(f"在第 {step} 步发生故障: {e}") print("退出。重新运行脚本以从最新检查点恢复。") raise SystemExit("训练中断") print("训练成功完成!")在这个修改后的循环中,如果发生故障,load_checkpoint函数(其实现细节我们稍后会讨论)可以从最近保存的检查点恢复状态,允许训练从例如第500,001步而不是第0步恢复。这大幅减少了计算资源的浪费。尽管检查点机制的主要原因是应对意外故障的容错能力,但它也提供了操作灵活性。检查点允许计划性停机,例如为了集群的计划维护或重新配置训练作业。它们也能使训练在不同硬件上恢复,或通过从中间状态分叉来寻找不同训练路径。考虑到LLM训练所需的巨大时间和资源投入,检查点不仅仅是便利;它是成功完成这些要求高的项目的基本要求。接下来的章节将详细说明需要保存的组件、在分布式环境中管理检查点的策略以及实现的最佳实践。