趋近智
大师班
成功保存检查点只是成功了一半;从该检查点正确恢复训练的能力对于实现容错的好处同等重要。简单的恢复,可能只加载模型权重,会导致次优的训练过程或不正确的结果。真正的恢复需要将整个训练状态恢复到中断前的精确位置。
恢复完整的训练状态不仅涉及加载模型参数,还包括优化器、学习率调度器以及可能的DataLoader进度和随机数生成器状态。未能恢复这些组件中的任何一个都可能会使后续的训练过程失效。例如,如果自适应优化器(如AdamW)在没有其累积动量和方差估计的情况下重新启动,将有效地重置其学习路径,可能导致之前取得的重大进展失效。同样,从头开始重启带有预热和衰减的学习率调度器会显著改变优化路径。
我们来研究一下实现恢复机制所涉及的步骤。
首先,你的训练脚本需要有逻辑来检测是否请求恢复操作,这通常通过命令行参数或指定检查点路径的配置设置来完成。通常自动寻找指定目录中最新的有效检查点是很实用的。
import torch
import os
import glob
def find_latest_checkpoint(checkpoint_dir):
"""根据迭代次数查找最新的检查点文件。"""
list_of_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_*.pt'))
if not list_of_files:
return None
# 假设文件名格式 'checkpoint_iter_XXXXX.pt'
latest_file = max(
list_of_files,
key=lambda f: int(f.split('_')[-1].split('.')[0])
)
return latest_file
# --- 在你的主训练脚本中 ---
# config.resume_from_checkpoint = True 或 False
# config.checkpoint_dir = '/检查点路径'
# config.resume_checkpoint_path = None # 可选:指定确切路径
resume_path = None
if config.resume_from_checkpoint:
if config.resume_checkpoint_path:
resume_path = config.resume_checkpoint_path
else:
resume_path = find_latest_checkpoint(config.checkpoint_dir)
if resume_path and os.path.isfile(resume_path):
print(f"从检查点恢复训练:{resume_path}")
checkpoint = torch.load(resume_path, map_location='cpu') # 先加载到CPU
else:
print("从头开始训练。")
checkpoint = None
在将组件移动到目标设备之前,先将检查点加载到CPU (map_location='cpu') 可以防止GPU内存突然升高,尤其是在多GPU配置中。
加载检查点字典后,你需要恢复核心训练组件的状态。
# 假设模型、优化器和调度器已经初始化
# (如同从头开始训练时那样)
model = YourTransformerModel(config)
optimizer = torch.optim.AdamW(
model.parameters(), lr=config.learning_rate, ...
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...) # 示例调度器
start_iter = 0
best_val_loss = float('inf')
if checkpoint is not None:
# 恢复模型状态
# 如果架构略有变化,处理潜在不匹配(使用strict=False)
# 注意:为精确恢复,推荐使用strict=True。
model.load_state_dict(
checkpoint['model_state_dict'], strict=True
)
print("模型状态已加载。")
# 恢复优化器状态
# 对自适应优化器和学习率动量很重要
if 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print("优化器状态已加载。")
else:
print(
"警告:检查点中未找到优化器状态。"
"从头开始初始化优化器。"
)
# 恢复学习率调度器状态
if 'scheduler_state_dict' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print("调度器状态已加载。")
else:
print(
"警告:检查点中未找到调度器状态。"
"从头开始初始化调度器。"
)
# 恢复训练进度
if 'iteration' in checkpoint:
start_iter = checkpoint['iteration'] + 1 # 从下一次迭代开始
print(f"从迭代 {start_iter} 恢复。")
if 'best_val_loss' in checkpoint:
best_val_loss = checkpoint['best_val_loss']
print(f"已加载最佳验证损失:{best_val_loss:.4f}")
# 恢复RNG状态以保证可复现性(可选但建议)
if 'rng_states' in checkpoint:
torch.set_rng_state(
checkpoint['rng_states']['torch_rng_state']
)
# 也可能需要恢复numpy和python的随机状态
# import numpy as np
# np.random.set_state(checkpoint['rng_states']['numpy_rng_state'])
# import random
# random.setstate(checkpoint['rng_states']['python_rng_state'])
print("RNG状态已加载。")
# 在加载状态字典后,将模型移动到目标设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 重要:在加载后,将优化器状态移动到正确的设备
# 一些框架会自动处理,但在原生PyTorch中需要注意
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
print(f"训练将从迭代 {start_iter} 开始/恢复")
一个经常被忽略的重要事项是确保优化器的状态张量在加载状态字典后被移动到正确的设备上。虽然模型参数通过model.to(device)移动,但优化器状态(例如AdamW中的动量缓冲区)驻留在优化器对象中,可能需要明确的设备放置。
这可能是恢复过程中最复杂的部分。简单地在恢复时从数据集的开头重新启动数据加载器意味着你将重新处理在该周期内中断前已见过的数据样本。这会使该周期的训练数据分布出现偏差并延缓进度。
理想情况下,你希望数据加载器正好从中断的地方继续。策略包括:
# --- 在训练循环内部,循环开始之前 ---
# 假设 'train_dataloader' 已初始化(可能使用基于周期的种子)
# 我们需要知道我们正在恢复到当前周期的哪次迭代
# 计算周期数和周期内的迭代次数
# 假设适用 'config.gradient_accumulation_steps'
effective_batch_size = config.batch_size * config.num_gpus # 为DP/DDP调整
# 如果使用梯度累积,effective_batch_size不会改变,
# 但每周期的步数可能会变。我们假设步数基于优化器步数。
iterations_per_epoch = len(train_dataloader)
# 或者根据数据集大小 / effective_batch_size计算
start_epoch = start_iter // iterations_per_epoch
resume_iter_within_epoch = start_iter % iterations_per_epoch
print(
f"恢复到周期 {start_epoch},从周期内的迭代 "
f"{resume_iter_within_epoch} 开始。"
)
# --- 在周期循环内部 ---
for epoch in range(start_epoch, config.num_epochs):
# 如果需要,重新设定数据加载器采样器种子以保证可复现性
# train_dataloader.sampler.set_epoch(epoch) # 如果使用DistributedSampler
data_iter = iter(train_dataloader)
# 如果在当前周期内恢复,跳过已处理的批次
if epoch == start_epoch and resume_iter_within_epoch > 0:
print(
f"在周期 {epoch} 中跳过 {resume_iter_within_epoch} 个批次,以恢复状态..."
)
for _ in range(resume_iter_within_epoch):
try:
next(data_iter)
except StopIteration:
# 如果检查点逻辑正确,不应发生
print("错误:尝试跳过数据加载器末尾。")
break
print("跳过完成。")
# 现在开始当前周期的实际训练迭代
for step_in_epoch in range(
resume_iter_within_epoch, iterations_per_epoch
):
current_global_iter = epoch * iterations_per_epoch + step_in_epoch
# 获取批次(如果跳过出错,处理潜在的StopIteration)
try:
batch = next(data_iter)
except StopIteration:
print(
f"警告:在周期 {epoch} 的步进 "
f"{step_in_epoch} 处,数据加载器意外耗尽。"
)
break
# ... 训练步进的其余部分:将批次移动到设备,正向传播,
# 反向传播,优化器步进 ...
# 为后续周期重置恢复标记
if step_in_epoch == resume_iter_within_epoch:
resume_iter_within_epoch = 0
# 确保如果内部循环提前完成,标记会被重置
resume_iter_within_epoch = 0
这种跳过机制确保模型在整个训练运行中大致看到每个数据样本预期的次数,保持训练过程的完整性。像torch.utils.data.DataLoader与DistributedSampler等采样器结合使用的库,通常需要仔细处理周期种子设定(sampler.set_epoch(epoch)),以确保在分布式设置中数据正确混洗和分配,尤其是在恢复时。
在分布式训练环境(DDP、FSDP、ZeRO)中,恢复需要仔细协调:
torch.distributed.barrier()来确保所有进程在加载之前都已定位到检查点。load_checkpoint函数通常会自动为每个进程加载相应的分片。直接使用torch.load和optimizer.load_state_dict可能无法正确处理这些分片状态;请依赖框架的工具。DistributedSampler时,确保在恢复时正确调用set_epoch(),并且跳过逻辑考虑了每个进程的数据分片。每个进程在其自己的数据部分中跳过批次。# 使用DeepSpeed的检查点加载示例(简化)
# 假设 'model_engine' 是DeepSpeed引擎,封装了模型、优化器等。
# 在你的设置代码中的某个地方:
load_path, client_state = model_engine.load_checkpoint(
config.checkpoint_dir, tag=config.checkpoint_tag
)
if load_path is not None:
print(f"从检查点 {load_path} 恢复训练")
# DeepSpeed的load_checkpoint会返回client_state,该状态可能包含
# 迭代计数、RNG状态等,这些是你保存的。
start_iter = client_state.get('iteration', 0) + 1
# ... 从client_state恢复其他自定义状态 ...
else:
print("从头开始训练。")
start_iter = 0
# DeepSpeed处理模型、优化器和调度器状态的恢复。
# 你主要需要恢复你在保存时添加到'client_state'中的自定义状态。
# 而且重要的是,根据恢复的'start_iter'处理数据加载器的跳过。
恢复后,一个好习惯是验证状态是否正确恢复。一个简单的检查方法是,在恢复训练的第一步之后立即记录损失和学习率,并将其与中断前记录的值(如果可用)进行比较。它们应该非常接近,考虑到微小的浮点差异和下一个数据批次的影响。显著的偏差可能表明恢复逻辑存在问题。在启动大规模任务之前,强烈建议在小规模运行中全面测试保存/恢复功能。
通过仔细恢复模型、优化器、调度器、数据加载器位置和其他元数据,你可以确保训练在中断后继续,从而节省时间和计算资源。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造