趋近智
state_dictstate_dict在投入大量精力训练机器学习模型后,保存你的工作成果以便后续使用、评估或部署是一个基本步骤。如果你来自 TensorFlow,你可能熟悉诸如 SavedModel 或 HDF5 之类的格式,它们通常会将模型的架构、权重,有时甚至包括训练配置捆绑到一个单一的包中。PyTorch 以略有不同的理念处理模型保存,主要围绕一个名为 state_dict 的对象。
理解 PyTorch 如何管理模型保存,尤其是它的 state_dict,对于有效保存和加载你的工作来说很重要。这种方法提供了灵活性,但与 TensorFlow 更为全面的保存格式相比,它需要一种不同的思维方式。
在 TensorFlow 生态系统中,你通常会遇到两种主要保存模型的方式:
SavedModel 格式:这是 TensorFlow 的标准序列化格式。一个 SavedModel 目录包含完整的 TensorFlow 程序,包括计算图、权重(变量)、资产以及定义模型如何使用的签名(例如,用于 TensorFlow Serving 进行服务)。它被设计为一种语言无关、密封且可恢复的 TensorFlow 模型表示。
HDF5 格式 (.h5 或 .keras):Keras 用户经常使用这种格式。HDF5 文件通常存储模型的架构、权重值以及训练配置(损失、优化器、评估指标)。这是一种方便保存和共享 Keras 模型的方式。
这两种 TensorFlow 格式都旨在提供一个相当完整的模型快照,让你能够加载和使用它,通常无需手头有原始模型创建代码(尽管最好还是保留它)。
state_dictPyTorch 保存模型信息的主要机制围绕着 state_dict。PyTorch 中的 state_dict 本质上是一个 Python 字典对象,它将模型中的每一层映射到其可学习参数(张量),例如权重和偏置。对于优化器(torch.optim.Optimizer),state_dict 包含优化器状态以及使用的超参数信息。
我们来考虑一个简单模型:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1440, 50) # 假设在卷积/池化后,输入尺寸会产生 1440 个特征
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = x.view(-1, 1440) # 展平
x = self.fc1(x)
return x
model = SimpleNet()
print(model.state_dict().keys())
运行这段代码会输出类似以下内容:
odict_keys(['conv1.weight', 'conv1.bias', 'fc1.weight', 'fc1.bias'])
state_dict 中的每个键都对应着模型中的一个参数张量。例如,conv1.weight 是 conv1 层的权重张量的键。
最重要的一点是,state_dict 只包含模型的参数。它不存储模型的架构(例如上面 SimpleNet 这样的 Python 类定义)。这是 PyTorch 中有意为之的设计选择。它使保存的状态保持最小化,并依赖你的代码来定义模型结构。这种以 Python 为核心的方法提供了很大的灵活性,因为模型的架构就是 Python 代码,易于修改和检查。
state_dict使用 torch.save() 保存模型的 state_dict 非常直接:
# 假设 'model' 是你 nn.Module 子类的一个实例
PATH = "my_model_state_dict.pt"
torch.save(model.state_dict(), PATH)
PyTorch 保存对象的常见文件扩展名是 .pt(PyTorch)或 .pth(PyTorch 历史上的)。在内部,torch.save() 使用 Python 的 pickle 模块来序列化 state_dict 对象。
要将参数加载回模型中,你首先需要模型类的一个实例。这是因为 state_dict 只包含参数,不包含结构信息。
# 首先,实例化你的模型结构
loaded_model = SimpleNet() # 你必须有 SimpleNet 类定义可用
# 然后,加载 state_dict
loaded_model.load_state_dict(torch.load(PATH))
# 如果你将模型用于推理,请务必调用 model.eval()
# 这会将 dropout 和 batch normalization 等层设置为评估模式
loaded_model.eval()
如果你正在恢复训练,你通常会省略 loaded_model.eval(),并确保模型处于训练模式(loaded_model.train()),这是默认状态。
state_dict 与 TensorFlow 格式的比较保存内容的这种区别对于从 TensorFlow 转向 PyTorch 的开发者来说是核心要点:
| 特点 | TensorFlow (SavedModel, HDF5) | PyTorch (state_dict) |
|---|---|---|
| 保存内容 | 架构、权重、优化器状态(通常)、服务签名(SavedModel) | 主要为可学习参数(权重、偏置)。优化器状态单独保存。 |
| 模型定义 | 通常自包含在保存的文件中。 | 需要模型的 Python 类定义单独可用。 |
| 序列化 | Protocol Buffers (SavedModel)、HDF5。 | Python 的 pickle 用于 state_dict 对象。 |
| 重建 | 直接加载到可用的模型对象中。 | 实例化模型类,然后将 state_dict 加载进去。 |
| 灵活性 | 结构化,适用于部署端点。 | 灵活性高,依赖 Python 代码定义结构。 |
这意味着当你通过 state_dict 共享或归档 PyTorch 模型时,你还必须提供定义模型架构的 Python 代码。仅仅 state_dict 不足以重建模型。
尽管 PyTorch 确实允许你保存整个模型对象(torch.save(model, PATH)),但保存 state_dict 通常是模型保存和共享的推荐做法。这是因为它将学习到的参数与保存时的特定 Python 代码和类结构解耦。如果你重构模型的 Python 文件,但对应 state_dict 的层名称和结构保持不变,你仍然可以加载参数。保存整个模型会序列化整个类,如果文件结构或类定义发生变化,这种方式可能会更脆弱。
这种以 state_dict 为核心的方法让你承担更多管理模型定义代码的责任,但提供了模型逻辑(其架构)与其学习状态(其参数)之间的清晰分离。在后续内容中,我们将探讨保存整个模型与仅保存 state_dict 之间的细微差异,并研究在训练过程中有用的检查点策略。
这部分内容有帮助吗?
state_dict概念以及保存和加载方法。© 2026 ApX Machine Learning用心打造