趋近智
训练大型语言模型通常是一场资源密集型持久战,而非短跑。任务可能在数百或数千个加速器上运行数天、数周甚至数月。在这些长期运行、复杂的分布式环境中,故障不仅可能发生,而且很有可能发生。硬件可能出故障,网络可能中断,抢占式实例可能被回收,或者软件错误可能意外出现。如果没有机制来平稳地处理这些中断,单次故障可能导致数周的计算白费,造成无法接受的延迟和成本超支。在这种情况下,检查点与容错机制成为您大模型运维策略中必不可少的组成部分。
检查点是指定期保存训练任务状态的做法。这种状态通常不仅包括模型的参数(权重),还包括优化器的状态(例如,Adam中的动量缓冲区)、当前的训练周期或步数、学习率调度器的状态,甚至可能包括数据加载器和随机数生成器的状态。保存这个完整的状态允许训练过程从中断的确切点恢复,而不是从头开始。
有效的检查点在大规模训练中提供了几项显著好处:
实施有效的检查点策略涉及几个决定:
频率: 您应该多久保存一次检查点?这需要权衡。
存储: 大型模型的检查点容量很大,从几十吉字节到数太字节不等。
格式与内容: 具体保存什么?
state_dict()或等效的模型参数。异步检查点: 为了最大限度减少对训练时间的影响,一些框架允许检查点在后台异步进行,将I/O操作与正在进行的计算重叠。这需要仔细实施以确保状态一致性。
检查点是基础,但一个真正容错的系统涉及更多:
以下是使用PyTorch类似语法的示例:
# --- 保存检查点 ---
def save_checkpoint(model, optimizer, scheduler, epoch, step, save_dir, is_best=False):
"""保存模型、优化器、调度器和训练进度。"""
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, f"checkpoint_step_{step}.pt")
# 收集状态字典
# 注意:对于分布式模型(DDP、FSDP、DeepSpeed),请使用相应的API
# 在秩0上收集完整状态字典或以分片格式保存。
model_state = model.state_dict()
optimizer_state = optimizer.state_dict()
scheduler_state = scheduler.state_dict()
torch.save({
'model_state_dict': model_state,
'optimizer_state_dict': optimizer_state,
'scheduler_state_dict': scheduler_state,
'epoch': epoch,
'step': step,
# 可以添加随机数生成器状态、损失值等。
}, checkpoint_path)
print(f"检查点已保存到 {checkpoint_path}")
if is_best:
best_path = os.path.join(save_dir, "checkpoint_best.pt")
shutil.copyfile(checkpoint_path, best_path)
print(f"最佳检查点已更新到步数 {step}")
# --- 加载检查点 ---
def load_checkpoint(model, optimizer, scheduler, load_path):
"""从检查点文件加载状态。"""
if not os.path.exists(load_path):
print(f"检查点文件未找到: {load_path}")
return 0, 0 # 返回起始周期和步数
# 将检查点加载到合适的设备上
# 使用 map_location 实现灵活性(例如,将GPU检查点加载到CPU)
checkpoint = torch.load(load_path, map_location=torch.device('cpu'))
# 加载状态
# 同样,如果适用,请使用分布式框架API
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch']
start_step = checkpoint['step'] + 1 # 从下一步继续
print(f"已从 {load_path} 加载检查点。从周期 {start_epoch},步数 {start_step} 继续。")
return start_epoch, start_step
# --- 训练循环中的使用示例 ---
# 初始化模型、优化器、调度器...
start_epoch = 0
global_step = 0
checkpoint_dir = "/path/to/checkpoints"
resume_from_checkpoint = "/path/to/checkpoints/checkpoint_latest.pt" # 或找到最新逻辑
if os.path.exists(resume_from_checkpoint):
start_epoch, global_step = load_checkpoint(model, optimizer, scheduler, resume_from_checkpoint)
for epoch in range(start_epoch, num_epochs):
for batch in dataloader:
# 训练步骤...
loss = train_step(model, batch)
# 更新优化器、调度器...
optimizer.step()
scheduler.step()
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optimizer, scheduler, epoch, global_step, checkpoint_dir)
# 如果需要,更新最新检查点的符号链接/标记
# 可选:根据保留策略删除旧检查点
if global_step % validation_interval == 0:
validation_loss = evaluate(model, validation_dataloader)
if validation_loss < best_validation_loss:
best_validation_loss = validation_loss
save_checkpoint(model, optimizer, scheduler, epoch, global_step, checkpoint_dir, is_best=True)
global_step += 1
请注意:上述代码仅供参考。实际实现,特别是对于使用DeepSpeed或FSDP等框架进行的分布式训练,需要那些框架提供的特定API才能正确处理分片状态。 例如,DeepSpeed提供了model_engine.save_checkpoint(save_dir)和model_engine.load_checkpoint(load_dir)。
大型语言模型的庞大体积给检查点带来了特定的挑战:
分片检查点保存示意图。每个秩将其部分的模型和优化器状态保存到共享存储上检查点目录中的一个专用文件。元数据协调这些分片。
总之,检查点与容错对于严肃的大模型训练和微调操作来说并非可选项。它们是有效管理长期运行、资源密集型任务的基本要求。通过实施全面的状态保存和恢复策略,与容错分布式框架集成,并高效管理检查点存储,您可以显著提高您大模型运维流程的可靠性和成本效益。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造