趋近智
requires_grad)backward()).grad)torch.nn 搭建模型torch.nn.Module 基类torch.nn 损失)torch.optim)torch.utils.data.Datasettorchvision.transforms)torch.utils.data.DataLoader深度学习模型的训练通常耗时,根据模型复杂度和数据集大小,可能需要数小时甚至数天。每次因系统中断、后续微调或仅为预测而停止时都从头开始训练是不现实的。因此,保存和加载模型检查点变得非常重要。
检查点记录了训练过程在特定时刻的状态,方便您日后恢复。有效保存和加载 PyTorch 模型和训练状态是必要的组成部分。
保存检查点时,您需要根据目的决定需要哪些信息。通常,您至少会保存模型的参数。如果您打算恢复训练,还应保存优化器的状态,以及当前周期数和最新验证损失等其他元数据。
PyTorch 模型有一个内部状态字典,通过 model.state_dict() 访问,其中包含模型各层的所有学习参数(权重和偏置)。这是保存模型学习信息建议的方式。
为什么保存 state_dict 而不是整个模型对象(例如 torch.save(model, PATH))?保存 state_dict 更具弹性,也更不容易出问题。对整个模型对象进行序列化保存会存储保存时使用的特定代码结构。如果您之后重构或更改模型类定义,加载序列化对象可能会失败或导致意外行为。仅保存状态字典将学习参数与代码结构分离,使加载更稳定。
同样,Adam 或 SGD 等优化器也有内部状态(例如,动量缓冲区、自适应学习率),这些状态在训练过程中会变化。为了精确地恢复训练,您应该使用 optimizer.state_dict() 保存优化器的状态。
torch.save 保存检查点PyTorch 使用 torch.save() 来序列化和保存对象。要保存检查点,您通常会创建一个字典,其中包含模型的状态字典、优化器的状态字典以及任何其他相关信息,然后保存这个字典。
以下是在训练循环中保存检查点的常见模式:
# 假设 model, optimizer, epoch, loss 已定义
# PATH = "path/to/your/checkpoint.pth" # 定义您的保存路径
checkpoint = {
'epoch': epoch + 1, # 保存下一个要开始的周期数
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss, # 或者可以是验证损失
# 添加任何其他您想保存的指标或信息
# 'validation_accuracy': val_acc,
}
torch.save(checkpoint, PATH)
print(f"检查点已在周期 {epoch} 保存到 {PATH}")
您可以定期(例如,每 10 个周期)保存检查点,或在模型在验证集上取得新最佳表现时保存。
torch.load 和 load_state_dict 加载检查点要加载检查点,您首先使用 torch.load() 从文件中反序列化保存的字典。然后,您需要将状态字典加载回您的模型和优化器实例中。
注意: 在加载模型和优化器状态之前,您必须先创建它们的实例。load_state_dict() 方法将参数加载到一个现有对象中;它不会重新创建对象本身。
如果您只需要模型进行预测(推理),并且不打算恢复训练,通常只需加载 model_state_dict。
# 首先,实例化您的模型结构
model = YourModelClass(*args, **kwargs)
# 定义您保存的检查点路径
PATH = "path/to/your/checkpoint.pth"
# 加载检查点字典
checkpoint = torch.load(PATH)
# 从检查点加载模型状态字典
model.load_state_dict(checkpoint['model_state_dict'])
# 将模型设置为评估模式
model.eval()
# 现在模型已准备好进行推理
# with torch.no_grad():
# outputs = model(inputs)
设置 model.eval() 很关键,因为它会禁用 Dropout 等层,并使用运行统计数据对批量归一化层进行归一化,这是推理时的正确操作。
如果您想从上次中断的地方继续训练,您需要加载模型和优化器的状态,并获取其他已保存的元数据,例如周期数。
# 首先实例化模型和优化器
model = YourModelClass(*args, **kwargs)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 或您选择的优化器
# 定义路径
PATH = "path/to/your/checkpoint.pth"
start_epoch = 0
best_loss = float('inf') # 示例:初始化最佳损失
# 检查检查点是否存在以进行加载
if os.path.exists(PATH):
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
best_loss = checkpoint['loss'] # 加载之前的损失
print(f"检查点已加载。从周期 {start_epoch} 继续训练")
# 将模型设置为训练模式
model.train()
# 现在您可以继续训练循环,从 start_epoch 开始
# for epoch in range(start_epoch, num_epochs):
# # ... 训练步骤 ...
设置 model.train() 可确保 Dropout 和批量归一化等层在训练期间表现正常。
有时,您可能会保存一个在 GPU 上训练的模型,然后需要在只有 CPU 的机器上加载它,反之亦然。默认情况下,torch.load() 会尝试将张量加载到它们保存时所在的设备上。如果该设备不可用,这可能会导致错误。
为了处理这种情况,您可以使用 torch.load() 中的 map_location 参数。
# 将在 GPU 上训练的模型加载到 CPU
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
# 将任何模型加载到当前可用设备(如果 GPU 可用则使用 GPU,否则使用 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# 记住也要将您的模型移到设备上
model.to(device)
保存和加载检查点是深度学习工作流程中很普通的一部分。通过熟练掌握使用 torch.save、torch.load 和 load_state_dict 的这些方法,您可以保证训练进程安全,模型可以重复使用,并且训练过程面对中断时表现稳定。请记住,为了获得最大的适应性和稳定性,要保存模型和优化器的 state_dict,以及相关的元数据。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造