趋近智
大师班
在为长时间运行的训练任务实现检查点时,首要的考量是保存过程如何与正在进行的训练计算相互影响。训练在写入检查点时是完全暂停,还是可以同时进行?这引出了两种主要方式:同步检查点和异步检查点。选择它们时需要权衡简单性、一致性和性能开销。
同步检查点是最直接的方式。当检查点触发时(例如,在达到一定的训练步数或经过一段时间后),训练过程会明确地暂停所有计算。然后,它收集所需的状态组成部分:模型参数、优化器状态、学习率调度器状态、当前的轮次或步数,以及可能的数据加载器迭代器状态。一旦所有状态收集完毕,它们就会被序列化并写入持久化存储(如分布式文件系统或云存储)。只有在写入操作成功完成后,训练过程才会恢复计算。
在分布式训练环境下,同步检查点需要所有参与工作节点之间的协调。通常,在保存之前会使用屏障同步,以确保所有工作节点都到达同一点。一个工作节点(通常是0号节点)可能会被指定从其他节点收集状态,或者每个工作节点保存自己部分的状态。保存之后可能会使用另一个屏障,以确保所有工作节点都等到检查点完全写入后再继续。
优点:
缺点:
以下是一个在分布式环境中,使用类似PyTorch语法的训练循环中同步检查点的表示:
# 假设已初始化 torch.distributed
def save_synchronous_checkpoint(
rank, world_size, model, optimizer, scheduler, step, path
):
# 确保所有进程在保存前都到达这一点
if world_size > 1:
torch.distributed.barrier()
if rank == 0: # 0号节点处理合并状态的保存
print(
f"Rank {rank}: Starting synchronous checkpoint save at step {step}..."
)
# 在实际场景中,状态可能会从其他节点收集
# 或者每个节点保存自己部分(例如,使用DeepSpeed/FSDP辅助函数)
state = {
'step': step,
'model_state_dict': model.state_dict(),
# 或者 model.module.state_dict()
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict()
# 可能添加数据加载器状态、随机数生成器状态等
}
torch.save(state, path)
print(f"Rank {rank}: Finished synchronous checkpoint save to {path}.")
else:
# 其他节点等待0号节点完成保存
pass
# 确保0号节点上的保存完成,然后所有节点才能继续
if world_size > 1:
torch.distributed.barrier()
# --- 训练循环内部 ---
model.train()
for step, batch in enumerate(data_loader):
# 前向传播、反向传播、优化器步进...
outputs = model(batch['input_ids'])
loss = calculate_loss(outputs, batch['labels'])
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# 定期检查点
if step % checkpoint_interval == 0 and step > 0:
checkpoint_path = f"/path/to/checkpoints/step_{step}.pt"
# --- 阻塞保存操作 ---
save_synchronous_checkpoint(
rank,
world_size,
model,
optimizer,
scheduler,
step,
checkpoint_path
)
# --- 训练仅在保存完成后恢复 ---
# ... 循环的其余部分(日志记录、评估等)
下图说明了同步检查点的阻塞特性。
训练在所有节点同步保存检查点时完全停止。
异步检查点旨在减轻同步保存的性能开销。其核心思路是将写入检查点这一计算开销大的I/O操作与主训练循环分离。
当检查点触发时,主训练进程会启动保存操作,但不会等待其完成。这通常通过以下方式实现:
优点:
缺点:
实现异步检查点通常涉及使用线程或多进程库。
import threading
import torch
import time
import os
# 假设 torch.distributed 已初始化 (rank, world_size)
# 实际模型、优化器等的占位符
class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x): return self.linear(x)
def state_dict(self): return {'param': torch.randn(10, 10)}
model = DummyModel()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=10, gamma=0.1
)
# 用于跟踪后台保存线程的全局变量
checkpoint_thread = None
def background_save_task(state, path):
"""由后台线程执行的函数。"""
print(f"Background Saver: Starting async save to {path}...")
try:
# 模拟慢速I/O
time.sleep(5) # 模拟保存时间
# 如果目录不存在则创建
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(state, path)
print(f"Background Saver: Finished async save to {path}.")
except Exception as e:
print(f"Background Saver: Error during checkpointing: {e}")
def save_asynchronous_checkpoint(
rank, world_size, model, optimizer, scheduler, step, path
):
global checkpoint_thread
# 确保前一个后台保存完成,然后才能开始新的保存
if checkpoint_thread is not None and checkpoint_thread.is_alive():
print(f"Rank {rank}: Waiting for previous async checkpoint to finish...")
checkpoint_thread.join() # 等待前一个线程完成
if rank == 0: # 0号节点启动并管理保存线程
print(
f"Rank {rank}: Initiating asynchronous checkpoint save at step {step}..."
)
# --- 快速复制状态 ---
# 如有必要,请使用deepcopy,state_dict()通常返回副本/视图
state = {
'step': step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict()
}
# --- 启动后台线程 ---
checkpoint_thread = threading.Thread(
target=background_save_task, args=(state, path)
)
checkpoint_thread.start()
print(f"Rank {rank}: Background save launched. Training continues.")
else:
# 其他节点可能需要最少的协调,例如,确保它们
# 如果一致性非常重要,不要进展太远,
# 但通常它们会继续训练。
# 如果需要复制状态前的严格一致性,
# 则在复制状态*之前*可能需要一个屏障。
pass
# --- 训练在所有节点上立即继续 ---
# 注意:此处没有屏障,允许重叠
# --- 训练循环内部 ---
step = 0
checkpoint_interval = 5 # 示例:每5步检查点
max_steps = 20
print("启动模拟训练循环...")
while step < max_steps:
step += 1
print(f"主循环:训练步 {step}")
# 模拟训练工作
time.sleep(0.5)
# model(...), loss.backward(), optimizer.step()...
if step % checkpoint_interval == 0:
# --- 非阻塞保存启动 ---
checkpoint_path = f"/tmp/async_checkpoints/step_{step}.pt"
save_asynchronous_checkpoint(
0, 1, model, optimizer, scheduler, step, checkpoint_path
)
# 为简单起见,假设节点0,world_size为1
# 等待循环退出后最后一个检查点线程完成
if checkpoint_thread is not None and checkpoint_thread.is_alive():
print("主循环:等待最终检查点完成...")
checkpoint_thread.join()
print("模拟训练循环完成。")
下图说明了异步检查点如何将I/O与计算重叠。
主训练线程仅短暂暂停以复制状态,然后继续计算,而实际保存则在后台线程中进行。
最佳方式取决于具体的训练设置和优先级:
实践中,对于检查点时间可能很长(数分钟或更久)的超大型模型,异步检查点通常更受青睐,以最大限度地利用昂贵的GPU资源,尽管增加了实现复杂性。细致的实现和测试是确保异步保存过程可靠性所必需的。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造