趋近智
state_dict当您将模型从 PyTorch 灵活的 Python 开发训练环境转为生产部署时,通常需要一种方法将模型序列化为独立于 Python、可优化且可移植的格式。这就是 TorchScript 的作用。可以将其看作一座桥梁,将您的动态 PyTorch 模型转换为更静态、图表般的表示形式,以便在各种环境中运行,包括 C++ 运行时、移动设备或 Python 可能不够理想或性能不足的服务器。
对于熟悉 TensorFlow 的人来说,TorchScript 的作用类似于 tf.function 将 Python 代码转换为 TensorFlow 图,以及 SavedModel 如何将此图与权重一起打包以供部署。尽管 PyTorch 的即时执行(定义即运行)因其即时性和 Python 式的风格而非常适合研究和实验,但生产环境通常能从静态图表示所提供的优化和可移植性中获益。
TorchScript 是 PyTorch 模型的一种中间表示 (IR)。它允许您创建可序列化和可优化的模型版本,这些版本不依赖于 Python 运行时。这意味着您可以在 Python 中定义模型,然后将其转换为 TorchScript 以便:
TorchScript 本质上捕获了模型的计算图,使其更适用于这些训练后操作。
有两种主要方法可以将您的 PyTorch nn.Module 转换为 TorchScript 模块:追踪和脚本化。
从 Python 定义的
nn.Module到可部署的 TorchScript 模型,可通过追踪或脚本化完成。
torch.jit.trace 进行追踪追踪的工作原理是使用一些示例输入执行您的模型一次。PyTorch 会记录所有在这些输入流经模型时执行的操作,从而有效地“追踪”一条路径。这条被记录的操作序列随后形成了 TorchScript 图。
import torch
import torch.nn as nn
# 一个简单的演示模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# 实例化模型
model = SimpleNet()
model.eval() # 设置为评估模式
# 提供一个示例输入张量
example_input = torch.randn(1, 10) # 批大小 1,10 个特征
# 追踪模型
try:
traced_model = torch.jit.trace(model, example_input)
print("模型追踪成功!")
# 您可以检查“代码”(图表的 Python 类似表示)
# print(traced_model.code)
except Exception as e:
print(f"追踪过程中出错:{e}")
# 保存追踪的模型
traced_model_path = "traced_simple_net.pt"
traced_model.save(traced_model_path)
print(f"追踪的模型已保存到 {traced_model_path}")
# 加载追踪的模型
loaded_traced_model = torch.jit.load(traced_model_path)
print("追踪的模型加载成功。")
# 您现在可以使用加载的追踪模型进行推理
output = loaded_traced_model(example_input)
print("从加载的追踪模型输出:", output.shape)
追踪的优点:
追踪的局限性:
if 语句或循环,追踪只会记录您提供的特定 example_input 所采用的路径。其他路径可能会被遗漏。例如,如果 if 条件取决于 x.sum() > 0 且您的示例输入使其为真,那么 else 分支将不会是追踪图的一部分。torch.jit.script 进行脚本化另一方面,脚本化涉及 TorchScript 编译器直接分析您的 Python 源代码(特别是 forward 方法及其调用的任何函数或模块)。它将此 Python 代码转换为 TorchScript 中间表示。这种方法对于具有动态控制流的模型更有效。
您可以将 @torch.jit.script 作为装饰器应用于函数或整个 nn.Module。对于 nn.Module,它通常会编译 forward 方法以及您显式装饰或从 forward 中调用的任何其他方法。
import torch
import torch.nn as nn
class ScriptableNet(nn.Module):
def __init__(self, D_in, H, D_out):
super(ScriptableNet, self).__init__()
self.linear1 = nn.Linear(D_in, H)
self.linear2 = nn.Linear(H, D_out)
def forward(self, x: torch.Tensor, use_relu: bool) -> torch.Tensor:
h_relu = torch.relu(self.linear1(x))
# 控制流示例
if use_relu:
y_pred = self.linear2(h_relu)
else:
y_pred = self.linear2(self.linear1(x)) # 第一层输出不应用 ReLU
return y_pred
# 实例化模型
script_model_instance = ScriptableNet(10, 20, 5)
script_model_instance.eval()
# 脚本化模型实例
try:
scripted_model = torch.jit.script(script_model_instance)
print("模型脚本化成功!")
# 您可以检查生成的“代码”
# print(scripted_model.code)
except Exception as e:
print(f"脚本化过程中出错:{e}")
# 保存脚本化的模型
scripted_model_path = "scripted_net.pt"
scripted_model.save(scripted_model_path)
print(f"脚本化的模型已保存到 {scripted_model_path}")
# 加载脚本化的模型
loaded_scripted_model = torch.jit.load(scripted_model_path)
print("脚本化的模型加载成功。")
# 使用不同的控制流路径进行测试
example_input = torch.randn(1, 10)
output_with_relu = loaded_scripted_model(example_input, True)
output_without_relu = loaded_scripted_model(example_input, False)
print("带 ReLU 路径的输出:", output_with_relu.shape)
print("不带 ReLU 路径的输出:", output_without_relu.shape)
脚本化的优点:
if 语句、循环和其他 Python 结构通常会得到保留。脚本化的注意事项:
x: torch.Tensor)可以显著帮助 TorchScript 编译器理解您的代码,这通常是一个好的做法。您也可以脚本化单独的函数:
@torch.jit.script
def custom_activation(input_tensor: torch.Tensor) -> torch.Tensor:
if input_tensor.mean() > 0:
return torch.relu(input_tensor)
else:
return torch.sigmoid(input_tensor)
example_tensor = torch.randn(5)
print(custom_activation(example_tensor))
example_tensor_neg_mean = torch.tensor([-1.0, -2.0, -0.5])
print(custom_activation(example_tensor_neg_mean))
一旦您的模型采用 TorchScript 格式,您将获得多项优势:
traced_model.save() 或 scripted_model.save() 保存的 .pt 文件包含模型的架构(以图的形式)及其参数(权重和偏置)。这个单一文件可以在其他 Python 环境中加载,或者更重要地,在使用 LibTorch 的 C++ 等非 Python 环境中加载。这对于部署不希望或无法提供完整 Python 堆栈的模型来说非常重要。nn.Sequential 或具有非常线性、静态的操作流,torch.jit.trace 通常是获取 TorchScript 模型最快捷、最简单的方法。if x.sum() > 0: ... else: ...),或者如果追踪未能捕获完整行为,torch.jit.script 是更可靠的方法。TorchScript 是一种功能强大的机制,可将您的 PyTorch 模型从研发阶段带入生产环境。通过将模型转换为这种可序列化和优化的格式,您可以使它们适用于更广泛的部署情况,从而摆脱 Python 解释器的限制,并实现与各种应用环境的集成。这是重要的一步,类似于 TensorFlow 开发者如何使用 SavedModels 来打包其训练好的模型以用于推理引擎和服务平台。在接下来的部分中,我们将简要介绍 ONNX 以实现更广泛的互操作性,以及 TorchServe 用于部署这些模型。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造