趋近智
保存单个训练过程的状态很简单,正如前面一节所述。然而,大型语言模型训练几乎总是涉及多个计算节点和设备并行工作。这种分布式特性给检查点保存过程带来了很大的复杂性。仅仅让每个工作节点独立保存其状态是不够的;需要协调以确保所保存的集体状态代表整个训练任务的一个有效且一致的快照。如果缺乏这种协调,恢复训练可能会导致行为偏差或结果不正确。
主要困难源于对一致性的要求。所有参与的进程(通常称为‘秩’,或‘rank’)必须保存其训练状态部分,且这些部分对应于计算中的同一点,通常是在特定训练步骤结束时。如果不同的秩在稍微不同的时间保存,例如一个秩完成了其梯度更新,而另一个秩仍在计算梯度,那么生成的检查点将不一致,很可能无法使用。
确保一致性的最基本方法是同步。在启动保存操作之前,所有秩都必须同步,以确保它们在训练循环中到达了相同的逻辑点。在 PyTorch 分布式数据并行(DDP)等框架中,这通常通过使用屏障(barrier)等集体通信操作来实现。
import torch
import torch.distributed as dist
import os
# 假设 setup_distributed() 初始化进程组
# setup_distributed()
def save_checkpoint_distributed(
model, optimizer, scheduler, epoch, step, checkpoint_dir
):
"""保存检查点,在所有秩之间协调。"""
# 确保所有秩都已准备好保存
dist.barrier()
# 指定一个秩(通常是秩0)来处理非分片保存
if dist.get_rank() == 0:
# 如果检查点目录不存在则创建它
os.makedirs(checkpoint_dir, exist_ok=True)
# 准备状态字典
# 注意:对于 DDP,应保存 model.module 以剥离 DDP 包装器
state = {
'epoch': epoch,
'step': step,
'model_state_dict': model.module.state_dict(), # 保存底层模型
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
# 添加任何其他必要的状态(例如,RNG 状态、数据加载器状态)
}
# 定义检查点路径
checkpoint_filename = f"checkpoint_epoch_{epoch}_step_{step}.pt"
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
# 保存状态字典
torch.save(state, checkpoint_path)
print(f"Rank 0: 已将检查点保存到 {checkpoint_path}")
# 确保所有秩等待秩 0 完成保存后再继续
dist.barrier()
# 在训练循环中的使用示例(简化)
# model = ... # 您的 DDP 包装模型
# optimizer = ...
# scheduler = ...
# checkpoint_dir = "/path/to/checkpoints"
# current_epoch = 1
# current_step = 5000
# save_checkpoint_distributed(
# model, optimizer, scheduler, current_epoch, current_step, checkpoint_dir
# )
在上面的示例中,dist.barrier() 作为一个同步点。第一个屏障确保所有秩在秩 0 开始保存之前暂停。然后,秩 0 保存必要的状态字典。重要地,对于使用 DDP 包装的模型,我们保存 model.module.state_dict() 来存储原始模型的参数 (parameter),而不是 DDP 包装器本身。第二个屏障确保没有秩会继续下一个训练步骤,直到秩 0 成功完成保存操作。这可以防止在状态被保存时某些秩可能开始修改状态的竞态条件。
虽然这种秩 0 保存方法有效,但它有其局限性,尤其是在大规模情况下。将整个模型状态、优化器状态以及可能很大的梯度收集到单个秩上,会形成网络瓶颈,并对秩 0 提出大量内存要求。此外,保存过程本身也通过秩 0 变为串行。
一种更具可扩展性的方法,在您使用 ZeRO(零冗余优化器)等内存优化技术或张量/管道并行时特别适合,是保存分片检查点。在分片检查点中,每个秩仅保存其在整个训练状态中的一部分。
DeepSpeed 和 Megatron-LM 等库提供了更高级别的 API,它们简化了管理分片检查点的许多复杂性。它们处理同步,并确保每个秩保存与其在并行配置中角色相符的正确状态。
# 使用类似 DeepSpeed API 的示例(实际 API 可能有所不同)
# 假设 'model_engine' 是 DeepSpeed 包装的模型、优化器等。
# DeepSpeed 通常为检查点使用标签
checkpoint_tag = f"epoch_{epoch}_step_{step}"
checkpoint_dir = "/path/to/sharded/checkpoints"
# DeepSpeed 的 save_checkpoint 在内部处理分片和同步
# 它保存模型状态、优化器状态、调度器状态等。
# 每个秩将自己的分片写入目录。
save_status = model_engine.save_checkpoint(checkpoint_dir, checkpoint_tag)
if save_status:
print(
f"秩 {dist.get_rank()}:成功保存了 "
f"分片检查点 {checkpoint_tag}"
)
else:
print(f"秩 {dist.get_rank()}:未能保存分片检查点 {checkpoint_tag}")
# 这里不需要显式屏障,因为它由 DeepSpeed 函数管理。
当使用分片检查点时,checkpoint_dir 将包含多个文件,每个文件代表一个来自不同秩的分片,或包含元数据。加载过程也必须了解这种分片结构,以便在所有秩上正确重构全局状态。
在分布式设置中加载检查点也需要协调。
state_dict 后会自动处理模型参数 (parameter)的分布,但优化器状态可能需要根据具体设置进行手动处理。# 秩 0 保存方法的加载示例
def load_checkpoint_distributed(model, optimizer, scheduler, checkpoint_path):
"""将秩 0 保存的检查点加载到所有秩上。"""
# 确保检查点路径存在
if not os.path.exists(checkpoint_path):
print(f"警告:检查点路径 {checkpoint_path} 不存在。 "
f"将从头开始。")
return 0, 0 # 返回起始轮次/步数
# 首先将检查点加载到 CPU,以避免秩 0 上的 GPU 内存峰值
rank_device = 'cuda:%d' % dist.get_rank()
map_location = {'cuda:%d' % 0: rank_device} # 映射到当前秩的设备
checkpoint = torch.load(checkpoint_path, map_location=map_location)
# 加载模型状态(请记住,对于 DDP 使用 model.module)
model.module.load_state_dict(checkpoint['model_state_dict'])
# 加载优化器和调度器状态
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# 加载轮次和步数
epoch = checkpoint['epoch']
step = checkpoint['step']
print(f"秩 {dist.get_rank()}:已从 "
f"{checkpoint_path} 加载检查点(轮次 {epoch},步数 {step})")
# 确保所有秩都已加载后再继续
dist.barrier()
return epoch, step
# checkpoint_to_load = "/path/to/checkpoints/checkpoint_epoch_1_step_5000.pt"
# start_epoch, start_step = load_checkpoint_distributed(model, optimizer, scheduler, checkpoint_to_load)
model_engine.load_checkpoint)会处理为每个秩读取适当的分片并重构分布式状态。从用户的角度来看,这个过程简单很多,因为库管理了复杂性。# 使用类似 DeepSpeed API 的示例
# checkpoint_dir = "/path/to/sharded/checkpoints"
# checkpoint_tag = f"epoch_{epoch}_step_{step}" # 保存时使用的标签
# DeepSpeed 的 load_checkpoint 处理读取分片和分发状态
load_path, client_state = model_engine.load_checkpoint(
checkpoint_dir, checkpoint_tag
)
if load_path:
print(f"秩 {dist.get_rank()}:成功从 {load_path} 加载了分片 "
f"检查点 {checkpoint_tag}")
# client_state 通常包含轮次、步数等信息。
start_epoch = client_state.get('epoch', 0)
start_step = client_state.get('step', 0)
else:
print(f"秩 {dist.get_rank()}:未能找到检查点 "
f"{checkpoint_tag},将从头开始。")
start_epoch, start_step = 0, 0
正确处理分布式检查点对于可靠的大规模模型训练非常重要。虽然手动实现需要仔细的同步和状态管理,但使用 DeepSpeed 或 Megatron-LM 等分布式训练框架中的功能,通常通过自动化分片和同步,提供了一种更可靠和可扩展的方案。请务必始终彻底测试您的检查点保存和加载过程,以确保它们在您的特定分布式设置中正常工作。
这部分内容有帮助吗?
dist.barrier()等集合通信操作和分布式数据并行原则的详细信息。© 2026 ApX Machine LearningAI伦理与透明度•