趋近智
大师班
决定保存检查点的频率以及如何管理生成的文件,需要权衡几个相互制约的因素:因故障而丢失计算进度的风险、保存过程本身带来的开销,以及存储的成本和可用性。恰当的平衡对高效、可靠的大规模模型训练很有必要。
选择检查点频率时,主要的权衡点在于最大限度减少故障时可能丢失的工作量与最大限度减少保存期间产生的开销。
有几个因素会影响您特定设置的最佳频率:
触发检查点的常见策略包括:
N个训练步保存一次。这提供了训练进度方面的可预测间隔。
# PyTorch训练循环中的示例
import torch
import os
SAVE_EVERY_N_STEPS = 1000
checkpoint_dir = "/path/to/checkpoints"
global_step = 0 # 假设此计数器在每个训练步骤中递增
# 在您的训练循环内部...
# optimizer.step()
# scheduler.step()
global_step += 1
if global_step % SAVE_EVERY_N_STEPS == 0:
# 构建检查点状态字典
state = {
'step': global_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
# 添加调度器状态、随机状态等。
}
checkpoint_path = os.path.join(checkpoint_dir, f"step_{global_step}.pt")
print(f"Saving checkpoint to {checkpoint_path} at step {global_step}")
# 在实际场景中,使用保存函数,可能异步
# torch.save(state, checkpoint_path)
# 考虑在此处添加存储管理逻辑(见下文)
pass # 实际保存和管理的占位符
H小时保存一次。这实现起来简单,但在检查点之间训练进度的数量上可预测性较低,因为迭代速率可能变化。N步保存一次,或 每H小时保存一次,以先发生者为准。这为防止进度过慢和长时间未保存提供了保障。通常需要进行实验。从一个合理的频率开始(例如,每1000-5000步或每1-2小时),并根据观察到的稳定性及测量的开销进行调整。
LLM检查点,包含模型权重、优化器状态,并可能包含梯度统计信息(特别是使用ZeRO Stage 3时),会非常大,根据模型大小和分布式训练策略,可从千兆字节到太兆字节不等。在长时间训练运行中创建的每一个检查点都进行存储通常不切实际,因为有存储成本和容量限制。
存储位置的权衡:
保留策略:
由于存储所有检查点不可行,您需要一个策略来决定保留哪些、丢弃哪些。
K个: 保留最近的K个检查点。当保存检查点N时,删除检查点N-K。
K(例如,K=3或K=5)。仍然主要基于时间,不一定基于性能。M个: 定期监控一个验证指标(例如,困惑度)。保存与迄今观察到的M个最佳验证分数相关的检查点。这通常补充了保留最新检查点(s)的做法。
M个检查点。实施保留通常包括列出存储位置中现有的检查点,根据所选标准(步数、时间戳、验证分数)对其进行排序,并删除超出保留窗口的检查点。
# 保留最后K个检查点的示例逻辑
import os
import glob
import re
checkpoint_dir = "/path/to/checkpoints"
KEEP_LAST_K = 3
def manage_checkpoints(checkpoint_dir, keep_last_k):
"""删除较旧的检查点,只保留指定数量。"""
checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
# 提取步数,处理可能不匹配的文件
steps = []
for ckpt in checkpoints:
match = re.search(r"step_(\d+)\.pt$", os.path.basename(ckpt))
if match:
steps.append((int(match.group(1)), ckpt))
# 按步数排序(降序)
steps.sort(key=lambda x: x[0], reverse=True)
# 确定要删除的检查点
if len(steps) > keep_last_k:
checkpoints_to_delete = [ckpt_path for step, ckpt_path in steps[keep_last_k:]]
print(f"Found {len(steps)} checkpoints. Deleting {len(checkpoints_to_delete)} older checkpoints.")
for ckpt_path in checkpoints_to_delete:
try:
os.remove(ckpt_path)
print(f"Deleted {ckpt_path}")
except OSError as e:
print(f"Error deleting {ckpt_path}: {e}")
# 在成功保存新检查点后调用此函数
# manage_checkpoints(checkpoint_dir, KEEP_LAST_K)
流程展示了在保存新检查点后进行的检查点保留策略检查。
最终,频率和存储管理的选择取决于对您的训练环境的稳定性、性能特点、计算成本以及对潜在数据丢失的容忍度的仔细评估。使用明确定义的保存机制,结合基于新近度和性能的保留策略,是大型LLM训练的常见做法。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造