趋近智
通常需要保存 TensorFlow 模型的完整状态,而不仅仅是其权重。尽管使用 model.save_weights() 仅保存权重对于在相同架构之间迁移已学习的参数或作为轻量级检查点非常有用,但一个全面的方法涉及保存模型的整个状态。这不仅包含权重,还包括模型的架构(层之间的连接方式)及其训练配置(在 model.compile() 期间定义的优化器、损失函数和指标)。保存完整的模型可确保您或他人以后能够准确复原模型状态,用于推理、后续训练或分析,而无需定义模型结构的原始代码。
TensorFlow 通过 Keras 提供了使用 model.save() 方法保存整个模型的直接方式。此方法打包了重建模型所需的一切。
model.save()model.save() 函数将模型序列化到指定路径。默认情况下,它使用 TensorFlow 的标准 SavedModel 格式,这是大多数使用场景(尤其是部署)的推荐方法。或者,您可以明确保存为较旧的 Keras HDF5 格式。
当您调用 model.save() 并提供目录路径时,TensorFlow 会将模型保存为 SavedModel 格式。
# 假设 'model' 是一个已编译且已训练的 Keras 模型
# 将整个模型保存到名为 'my_full_model' 的目录中
model.save('my_full_model')
此命令会创建一个名为 my_full_model 的目录,其中包含以下组件:
saved_model.pb:这个 Protocol Buffer 文件存储了定义模型架构和计算逻辑的 TensorFlow 图。它包含前向传播(推理)以及(如果适用)训练图。variables/ 目录:此目录包含模型的已学习权重(参数),以适合高效加载的格式存储。assets/ 目录:一个可选目录,用于存放模型所需的任何外部文件,例如文本处理层所需的词汇文件。keras_metadata.pb:存储 Keras 特定的元数据,包括优化器状态、损失配置以及在 model.compile() 期间定义的指标。这对于恢复训练非常重要。SavedModel 格式是语言无关的,专为 TensorFlow Serving、TensorFlow Lite(用于移动/边缘设备)和 TensorFlow.js(用于网页浏览器)等服务环境设计。它不仅捕捉了 Keras 模型结构,还包含了底层的 TensorFlow 计算图,使其在部署中很有用。
.h5 或 .keras)您也可以将整个模型保存到一个 HDF5 文件中。这种格式将架构、权重和训练配置(优化器状态、损失函数、指标)打包到一个二进制文件中。尽管作为单个文件很方便,但 SavedModel 格式因其更广泛的兼容性和部署功能而通常更受青睐。
要以 HDF5 格式保存,请提供以 .h5 或较新的 .keras 扩展名(为 Keras V3 兼容性推荐)结尾的文件名。
# 将整个模型保存到单个 HDF5 文件(新格式)
model.save('my_full_model.keras')
# 将整个模型保存到单个 HDF5 文件(旧格式)
# model.save('my_full_model.h5')
这会创建一个单个文件(my_full_model.keras 或 my_full_model.h5),其中包含:
当您使用 model.save() 保存整个模型(无论哪种格式)时,您保存了:
model.compile() 中指定的任何指标。这种完整性很重要,因为它使您能够重新加载模型并从上次停止的地方准确恢复训练,同时保持优化器的动量和学习率调整。它还确保评估加载的模型时使用与原始设置期间定义的相同指标。
使用 tf.keras.models.load_model() 来加载使用 model.save() 保存的模型。TensorFlow 会根据提供的路径自动检测格式(SavedModel 目录或 HDF5 文件)。
import tensorflow as tf
# 从 SavedModel 格式加载模型
loaded_model_sm = tf.keras.models.load_model('my_full_model')
# 从 HDF5 格式加载模型
loaded_model_h5 = tf.keras.models.load_model('my_full_model.keras') # or .h5
# 验证加载模型的结构
loaded_model_sm.summary()
load_model 函数重构模型架构、加载权重并恢复训练配置(优化器、损失函数、指标)。加载的模型已使用保存的配置进行编译,并已准备好使用。
加载完整保存模型的一个重要优点是能够继续训练或直接用于推理。
# 假设您有新数据(new_data, new_labels)或想继续训练
# 加载的模型会记住其优化器状态和编译设置
# 例子:继续训练几个周期
history = loaded_model_sm.fit(training_data, training_labels, epochs=10, validation_data=(val_data, val_labels))
# 例子:对新的、未见过的数据进行预测
predictions = loaded_model_sm.predict(new_unseen_data)
由于优化器状态已保存和加载,训练会有效恢复,并基于之前的学习过程。
保存整个模型通常是最好的方法,当:
与仅保存权重相比,保存整个模型提供了一个更完整且可复现的快照,这对于可靠部署和协作工作流程很重要。请记住,SavedModel 格式是 TensorFlow 提供的用于共享和部署训练好的模型的标准且用途最广的选项。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造