深度学习模型的训练通常耗时,根据模型复杂度和数据集大小,可能需要数小时甚至数天。每次因系统中断、后续微调或仅为预测而停止时都从头开始训练是不现实的。因此,保存和加载模型检查点变得非常重要。检查点记录了训练过程在特定时刻的状态,方便您日后恢复。本节介绍如何有效保存和加载 PyTorch 模型和训练状态的必要组成部分。应该保存什么?保存检查点时,您需要根据目的决定需要哪些信息。通常,您至少会保存模型的参数。如果您打算恢复训练,还应保存优化器的状态,以及当前周期数和最新验证损失等其他元数据。PyTorch 模型有一个内部状态字典,通过 model.state_dict() 访问,其中包含模型各层的所有学习参数(权重和偏置)。这是保存模型学习信息建议的方式。为什么保存 state_dict 而不是整个模型对象(例如 torch.save(model, PATH))?保存 state_dict 更具弹性,也更不容易出问题。对整个模型对象进行序列化保存会存储保存时使用的特定代码结构。如果您之后重构或更改模型类定义,加载序列化对象可能会失败或导致意外行为。仅保存状态字典将学习参数与代码结构分离,使加载更稳定。同样,Adam 或 SGD 等优化器也有内部状态(例如,动量缓冲区、自适应学习率),这些状态在训练过程中会变化。为了精确地恢复训练,您应该使用 optimizer.state_dict() 保存优化器的状态。使用 torch.save 保存检查点PyTorch 使用 torch.save() 来序列化和保存对象。要保存检查点,您通常会创建一个字典,其中包含模型的状态字典、优化器的状态字典以及任何其他相关信息,然后保存这个字典。以下是在训练循环中保存检查点的常见模式:# 假设 model, optimizer, epoch, loss 已定义 # PATH = "path/to/your/checkpoint.pth" # 定义您的保存路径 checkpoint = { 'epoch': epoch + 1, # 保存下一个要开始的周期数 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, # 或者可以是验证损失 # 添加任何其他您想保存的指标或信息 # 'validation_accuracy': val_acc, } torch.save(checkpoint, PATH) print(f"检查点已在周期 {epoch} 保存到 {PATH}")您可以定期(例如,每 10 个周期)保存检查点,或在模型在验证集上取得新最佳表现时保存。使用 torch.load 和 load_state_dict 加载检查点要加载检查点,您首先使用 torch.load() 从文件中反序列化保存的字典。然后,您需要将状态字典加载回您的模型和优化器实例中。注意: 在加载模型和优化器状态之前,您必须先创建它们的实例。load_state_dict() 方法将参数加载到一个现有对象中;它不会重新创建对象本身。加载用于推理如果您只需要模型进行预测(推理),并且不打算恢复训练,通常只需加载 model_state_dict。# 首先,实例化您的模型结构 model = YourModelClass(*args, **kwargs) # 定义您保存的检查点路径 PATH = "path/to/your/checkpoint.pth" # 加载检查点字典 checkpoint = torch.load(PATH) # 从检查点加载模型状态字典 model.load_state_dict(checkpoint['model_state_dict']) # 将模型设置为评估模式 model.eval() # 现在模型已准备好进行推理 # with torch.no_grad(): # outputs = model(inputs)设置 model.eval() 很关键,因为它会禁用 Dropout 等层,并使用运行统计数据对批量归一化层进行归一化,这是推理时的正确操作。加载以恢复训练如果您想从上次中断的地方继续训练,您需要加载模型和优化器的状态,并获取其他已保存的元数据,例如周期数。# 首先实例化模型和优化器 model = YourModelClass(*args, **kwargs) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 或您选择的优化器 # 定义路径 PATH = "path/to/your/checkpoint.pth" start_epoch = 0 best_loss = float('inf') # 示例:初始化最佳损失 # 检查检查点是否存在以进行加载 if os.path.exists(PATH): checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] best_loss = checkpoint['loss'] # 加载之前的损失 print(f"检查点已加载。从周期 {start_epoch} 继续训练") # 将模型设置为训练模式 model.train() # 现在您可以继续训练循环,从 start_epoch 开始 # for epoch in range(start_epoch, num_epochs): # # ... 训练步骤 ...设置 model.train() 可确保 Dropout 和批量归一化等层在训练期间表现正常。处理 CPU/GPU 设备映射有时,您可能会保存一个在 GPU 上训练的模型,然后需要在只有 CPU 的机器上加载它,反之亦然。默认情况下,torch.load() 会尝试将张量加载到它们保存时所在的设备上。如果该设备不可用,这可能会导致错误。为了处理这种情况,您可以使用 torch.load() 中的 map_location 参数。# 将在 GPU 上训练的模型加载到 CPU checkpoint = torch.load(PATH, map_location=torch.device('cpu')) model.load_state_dict(checkpoint['model_state_dict']) # 将任何模型加载到当前可用设备(如果 GPU 可用则使用 GPU,否则使用 CPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(PATH, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) # 记住也要将您的模型移到设备上 model.to(device)保存和加载检查点是深度学习工作流程中很普通的一部分。通过熟练掌握使用 torch.save、torch.load 和 load_state_dict 的这些方法,您可以保证训练进程安全,模型可以重复使用,并且训练过程面对中断时表现稳定。请记住,为了获得最大的适应性和稳定性,要保存模型和优化器的 state_dict,以及相关的元数据。