趋近智
大师班
训练大型语言模型(LLM)并非一个快速过程。与可能在数小时内完成训练的小型模型不同,预训练一个先进的LLM通常需要连续运行大型加速器(GPU或TPU)集群数天、数周乃至数月。设想一个训练任务,使用1024块高端GPU持续30天。这表示超过730,000加速器小时。这种长时间、大规模的运行显著增加了遇到潜在中断的可能。
大规模分布式系统的现实是故障时有发生。在长时间运行中,遇到问题的可能性接近必然。这些中断可能源于多种情况:
如果没有定期保存进度的机制,任何此类中断都会强制整个训练过程从头开始。这会带来严重后果:
这就是检查点变得必不可少的地方。检查点是将训练作业的完整状态定期保存到持久存储(如分布式文件系统或云存储)的做法。这种状态不仅包括模型的参数(权重),还包括精确地从上次停止的地方恢复训练所需的一切,例如优化器的状态、学习率调度器的状态、当前的训练迭代或周期数,以及数据加载器的状态。
设想一个简化的训练循环:
# 未使用检查点的例子
import torch
import torch.optim as optim
model = MyLargeModel()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# 假设 data_loader 提供数据批次
for step in range(TOTAL_TRAINING_STEPS):
# --- 潜在故障点 ---
if some_failure_condition():
print("发生故障!从第 0 步重新开始。")
# 到 'step' 之前的所有进度都已丢失。
# 需要重新初始化模型、优化器,并从第 0 步开始。
raise SystemExit("训练失败")
batch = next(iter(data_loader))
outputs = model(batch['input_ids'])
loss = calculate_loss(outputs, batch['labels'])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % LOG_INTERVAL == 0:
print(f"步数: {step}, 损失: {loss.item()}")
print("训练成功完成!") # 仅在没有故障发生时才能到达
如果在一个百万步的训练运行中,第500,000步发生故障,为这500,000步所做的工作都白费了。检查点机制引入了保存点:
# 使用检查点的例子
import torch
import torch.optim as optim
import os
CHECKPOINT_DIR = "/path/to/persistent/storage/checkpoints"
CHECKPOINT_FREQ = 1000 # 每 1000 步保存一次
def save_checkpoint(model, optimizer, step, filename="checkpoint.pt"):
checkpoint_path = os.path.join(CHECKPOINT_DIR, f"step_{step}_{filename}")
state = {
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
# 添加调度器状态、随机状态等。
}
torch.save(state, checkpoint_path)
print(f"在第 {step} 步保存检查点至 {checkpoint_path}")
def load_checkpoint(model, optimizer):
# 寻找最新检查点的逻辑
latest_checkpoint_path = find_latest_checkpoint(CHECKPOINT_DIR)
if latest_checkpoint_path:
checkpoint = torch.load(latest_checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_step = checkpoint['step'] + 1
print(f"从第 {start_step} 步的检查点恢复")
return start_step
else:
print("未找到检查点,从头开始。")
return 0
model = MyLargeModel()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
start_step = load_checkpoint(model, optimizer) # 尝试恢复
for step in range(start_step, TOTAL_TRAINING_STEPS):
# --- 潜在故障点 ---
try:
batch = next(iter(data_loader))
outputs = model(batch['input_ids'])
loss = calculate_loss(outputs, batch['labels'])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % LOG_INTERVAL == 0:
print(f"步数: {step}, 损失: {loss.item()}")
# --- 定期保存进度 ---
if step % CHECKPOINT_FREQ == 0 and step > 0:
save_checkpoint(model, optimizer, step)
except Exception as e:
print(f"在第 {step} 步发生故障: {e}")
print("退出。重新运行脚本以从最新检查点恢复。")
raise SystemExit("训练中断")
print("训练成功完成!")
在这个修改后的循环中,如果发生故障,load_checkpoint函数(其实现细节我们稍后会讨论)可以从最近保存的检查点恢复状态,允许训练从例如第500,001步而不是第0步恢复。这大幅减少了计算资源的浪费。
尽管检查点机制的主要原因是应对意外故障的容错能力,但它也提供了操作灵活性。检查点允许计划性停机,例如为了集群的计划维护或重新配置训练作业。它们也能使训练在不同硬件上恢复,或通过从中间状态分叉来寻找不同训练路径。
考虑到LLM训练所需的巨大时间和资源投入,检查点不仅仅是便利;它是成功完成这些要求高的项目的基本要求。接下来的章节将详细说明需要保存的组件、在分布式环境中管理检查点的策略以及实现的最佳实践。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造