趋近智
state_dict尽管 PyTorch 的 state_dict 和 TorchScript 为在 PyTorch 生态系统内保存和部署模型提供了很好的机制,但有时你需要模型能在不同的机器学习框架中运行。也许你的团队使用 TensorFlow 进行部署,或者你希望使用对通用模型格式有优化支持的特定硬件加速器。在这种情况下,开放神经网络交换 (ONNX) 格式就变得非常有用了。
ONNX 是一个用于表示机器学习模型的开放标准。它定义了一组通用的运算符(模型的基本组成部分,如卷积或矩阵乘法)和一种通用文件格式(.onnx)。目标是实现互操作性:你可以在一个框架中训练模型(如 PyTorch),将其导出为 ONNX 格式,然后在另一个框架(如 TensorFlow、Caffe2、MXNet)或专用的 ONNX 运行时中加载并运行它。
对于熟悉 TensorFlow 生态系统的开发者来说,该生态系统包括用于部署的 SavedModel 和用于移动设备的 TensorFlow Lite,ONNX 可能看起来是多余的一步。然而,它提供了多个优点:
可以把 ONNX 看作你的神经网络模型的通用翻译器。
ONNX 工作流程:一个 PyTorch 模型被导出到 ONNX 格式,然后可以被各种运行时和工具使用,包括 TensorFlow 生态系统中的工具。
PyTorch 使用 torch.onnx.export() 函数内置支持将模型导出到 ONNX 格式。此函数会跟踪你的模型,将其操作转换为 ONNX 图。
让我们看一下 torch.onnx.export() 的常见参数:
model:你的 PyTorch 模型(torch.nn.Module 的实例)。args:模型期望的示例输入元组。此输入用于跟踪模型的执行路径。此示例输入的形状和数据类型很重要。f:ONNX 模型将要保存的路径(例如,"my_model.onnx")。input_names:(可选)为 ONNX 图中的输入节点指定的名称列表。output_names:(可选)为 ONNX 图中的输出节点指定的名称列表。
"* dynamic_axes:(可选)一个字典,指定输入/输出的哪些轴是动态的(例如,批大小、序列长度)。这对模型非常有用。例如,{'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} 表示 'input' 和 'output' 的第一个维度是动态的,并命名为 'batch_size'。"opset_version:要使用的 ONNX 运算符集版本。ONNX 在发展,新版本增加了对更多运算符的支持。通常最好使用目标部署环境支持的、相对较新的版本。这是一个导出基本 PyTorch 模型的简单例子:
import torch
import torch.nn as nn
import torch.onnx
# 定义一个简单模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 5) # 输入特征数:10,输出特征数:5
def forward(self, x):
return self.linear(x)
# 实例化模型
model = SimpleModel()
model.eval() # 将模型设置为评估模式
# 创建与模型预期输入形状匹配的虚拟输入
# 批大小为 1,10 个输入特征
dummy_input = torch.randn(1, 10)
# 为 ONNX 模型中的输入和输出定义名称,以提高清晰度
input_names = ["input_tensor"]
output_names = ["output_tensor"]
# 导出模型
torch.onnx.export(model,
dummy_input,
"simple_model.onnx",
input_names=input_names,
output_names=output_names,
opset_version=12, # 指定 ONNX opset 版本
dynamic_axes={'input_tensor': {0: 'batch_size'}, # 批大小是动态的
'output_tensor': {0: 'batch_size'}}})
print("模型已导出到 simple_model.onnx")
当你运行此代码时,PyTorch 会使用 dummy_input 执行你的 SimpleModel,记录操作,并将其转换为 ONNX 格式,将结果保存为 simple_model.onnx。dynamic_axes 参数特别重要。没有它,导出的 ONNX 模型将期望固定批大小(本例中为 1)的输入。通过指定 dynamic_axes,我们告诉 ONNX 导出器批处理维度可以变化。
一旦你有了 .onnx 文件,就需要一个运行时来执行它。最常见的是 ONNX Runtime,它是一个用于 ONNX 模型的开源、高性能推理引擎。它是跨平台的,并支持硬件加速。
你可以通过 pip 安装 ONNX Runtime:
pip install onnxruntime
下面是如何加载并运行我们刚刚创建的 simple_model.onnx:
import onnxruntime
import numpy as np
# 创建一个 ONNX Runtime 推理会话
ort_session = onnxruntime.InferenceSession("simple_model.onnx")
# 准备一个示例输入(必须与模型的预期输入匹配,包括批大小)
# 在此次推理运行中使用批大小为 3
input_data = np.random.randn(3, 10).astype(np.float32)
# 从模型获取输入名称(或使用我们在导出时定义的名称)
input_name = ort_session.get_inputs()[0].name
# 运行推理
ort_inputs = {input_name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
# 输出是一个 numpy 数组列表
output_data = ort_outs[0]
print("输入形状:", input_data.shape)
print("输出形状:", output_data.shape)
print("输出数据(第一行):", output_data[0])
此代码片段加载 ONNX 模型,准备一个 NumPy 输入数组,并使用 ONNX Runtime 执行推理。你会注意到我们可以使用与导出时 dummy_input (1)不同的批大小(3),这得益于 dynamic_axes。
作为一名 TensorFlow 开发者,你可能会想 ONNX 如何适应你现有的工作流程。如果你收到一个 ONNX 模型(可能从 PyTorch 导出),你可以将其转换为 TensorFlow 格式(如 SavedModel),以便将其集成到基于 TensorFlow 的部署流程中(例如 TensorFlow Serving)。
onnx-tf 转换器是一个常用的工具:
pip install onnx-tf
然后你可以使用它将 .onnx 文件转换为 TensorFlow SavedModel:
# 假设 onnx_tf 已安装
from onnx_tf.backend import prepare
import onnx
# 加载 ONNX 模型
onnx_model = onnx.load("simple_model.onnx")
# 准备 TensorFlow 表示
tf_rep = prepare(onnx_model)
# 导出为 TensorFlow SavedModel
tf_rep.export_graph("simple_model_tf_savedmodel")
print("ONNX 模型已转换为 TensorFlow SavedModel 格式,位于 'simple_model_tf_savedmodel'")
这会创建一个标准的 TensorFlow SavedModel 目录,然后可以使用 tf.saved_model.load() 加载或通过 TensorFlow Serving 部署。
反之,如果你有 TensorFlow 模型,可以使用 tf2onnx 等工具将其转换为 ONNX(pip install tf2onnx)。这使得你可以将 TensorFlow 模型引入 ONNX 生态系统,可能供 PyTorch 或其他 ONNX 兼容工具使用,尽管本课程侧重于从 PyTorch 到 TensorFlow 的方向。
尽管 ONNX 很强大,但请记住以下几点:
opset_version 中可能没有直接的等效项。你可能需要简化模型或实现自定义 ONNX 运算符(这是一个高级话题)。请务必查看 ONNX 文档以了解支持的运算符。opset_version、ONNX Runtime 版本以及任何转换工具(如 onnx-tf)之间的兼容性。不匹配可能导致错误或意想不到的行为。torch.onnx.export 函数通过使用样本输入追踪模型来工作。如果你的模型具有基于输入数据变化的控制流(这对于旨在导出到 ONNX 的模型来说不太常见,但在 PyTorch 中可能出现),追踪可能无法捕捉到所有执行路径。前面讨论的 TorchScript,有时对于具有复杂控制流的模型,在导出到 ONNX 之前会更稳定。input_names 和 output_names 会使 ONNX 模型在后续使用时更方便,尤其是在使用 ONNX Runtime 或转换工具时。ONNX 在不同机器学习框架之间搭建了一个重要桥梁。对于学习 PyTorch 的 TensorFlow 开发者来说,理解如何将 PyTorch 模型导出到 ONNX,为将这些模型集成到现有的以 TensorFlow 为中心的部署流程中,或运用更广泛的 ONNX 生态系统进行优化和执行提供了机会。
这部分内容有帮助吗?
torch.onnx.export 函数及其参数。© 2026 ApX Machine Learning用心打造