趋近智
理解即时(JIT)编译的理论基础,例如跟踪与脚本化或自适应优化,非常必要。但是,要获得实际体会,需要检查这些JIT编译器生成的实际输出。分析中间表示(IR)或编译后的代码,能够具体说明算子合并、常量折叠和专门化等优化是如何应用的,从而将高级Python代码与低级执行计划直接关联起来。这种实际分析对于调试性能问题、验证优化效果以及提高对JIT过程的理解具有很高价值。
在本实践部分,我们将逐步介绍如何使用PyTorch (TorchScript) 和 TensorFlow (XLA) 等常用框架对简单模型片段进行JIT编译,然后分析生成的IR。我们假设您已安装PyTorch和TensorFlow,并拥有一个可用的Python环境。
PyTorch的JIT模块TorchScript提供了两种主要方法将Python代码转换为可优化的图表示:跟踪(torch.jit.trace)和脚本化(torch.jit.script)。我们从跟踪开始。
考虑一个简单的操作序列:一个线性层,接着是ReLU激活,再是一个线性层。
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self, input_features, hidden_features, output_features):
super().__init__()
self.linear1 = nn.Linear(input_features, hidden_features)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(hidden_features, output_features)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# 实例化模型并创建示例输入
input_size = 128
hidden_size = 256
output_size = 64
model = SimpleModel(input_size, hidden_size, output_size)
dummy_input = torch.randn(32, input_size) # 批大小为32
# 跟踪模型
traced_model = torch.jit.trace(model, dummy_input)
print("模型跟踪成功。")
跟踪会使用提供的示例输入执行模型,并记录所执行的操作。生成的traced_model包含一个专门针对输入形状(32×128)的静态图表示。我们可以检查这个图:
# 打印TorchScript图IR
print(traced_model.graph)
您将看到类似于以下内容的输出(细节可能因PyTorch版本略有不同):
graph(%self.1 : __torch__.SimpleModel,
%x : Float(32, 128, strides=[128, 1], requires_grad=0, device=cpu)):
%linear1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear1"](%self.1)
%3 : Float(256, 128, strides=[128, 1], device=cpu) = prim::GetAttr[name="weight"](%linear1)
%4 : Float(256, strides=[1], device=cpu) = prim::GetAttr[name="bias"](%linear1)
%5 : Tensor = aten::linear(%x, %3, %4) # <eval_with_key>.13:10:8
%relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)
%7 : Tensor = aten::relu(%5) # <eval_with_key>.14:8:8
%linear2 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear2"](%self.1)
%9 : Float(64, 256, strides=[256, 1], device=cpu) = prim::GetAttr[name="weight"](%linear2)
%10 : Float(64, strides=[1], device=cpu) = prim::GetAttr[name="bias"](%linear2)
%11 : Tensor = aten::linear(%7, %9, %10) # <eval_with_key>.15:8:8
return (%11)
分析:
%self.1)和输入张量%x及其跟踪到的形状Float(32, 128, ...)。aten::linear、aten::relu)或属性访问(prim::GetAttr)。%3、%4、%9、%10)作为图输入或属性嵌入,通过prim::GetAttr获取。if语句(除非它们在跟踪期间是常量)。它专门针对形状为(32, 128)的输入。如果操作支持,使用不同形状运行traced_model可能有效,但图本身并未明确表示动态形状逻辑。TorchScript会自动应用优化过程,包括合并。虽然简单的线性 -> ReLU合并并不总是保证发生,或在这种高级图转储中不总是直接可见(它通常发生在后续的降低阶段),但我们可以思考如何寻找它。更复杂的模式,尤其是在卷积网络中(Conv-BN-ReLU),是常见的合并目标。
我们可以检查潜在优化之后的图,尽管默认的打印输出通常显示的是初始图。要查看优化的图或执行计划,您可能需要分析工具或内部API(谨慎使用,因为它们可能会更改):
# 使用内部API的示例(可能会更改)
# 这提供了不同的视图,有时会显示优化/合并的操作
# print(torch._C._jit_get_profiling_executor_graph(traced_model.graph))
另外,分析JIT编译前后(例如,使用torch.profiler.profile)的性能特点,可以体现合并等优化的影响,即使IR可视化没有明确显示合并操作。
现在,我们使用脚本化,它直接分析Python字节码。
import torch
import torch.nn as nn
# 假设SimpleModel类如上定义
# 脚本化模型
scripted_model = torch.jit.script(model)
print("模型脚本化成功。")
# 打印TorchScript图IR
print(scripted_model.graph)
对于这个线性模型,输出图看起来会与跟踪图非常相似,因为它不包含跟踪会遗漏的Python控制流。但是,如果模型包含数据依赖的控制流(例如,基于张量值的if语句),脚本化图将使用prim::If节点明确表示此控制流,而跟踪图将只包含跟踪期间所取的路径。
分析(脚本化 vs 跟踪):
if、for),使其对具有动态行为的模型更具灵活性。跟踪仅捕获针对特定跟踪输入执行的操作。TensorFlow使用XLA(加速线性代数)作为其优化编译器,通常通过tf.function(jit_compile=True)调用。XLA在其自己的IR,即HLO(高级优化器IR)上进行操作。
import tensorflow as tf
# 定义一个简单函数
@tf.function(jit_compile=True)
def simple_computation(x, w1, b1, w2, b2):
y = tf.matmul(x, w1) + b1
y = tf.nn.relu(y)
z = tf.matmul(y, w2) + b2
return z
# 创建一些示例输入
input_shape = (32, 128)
hidden_shape = 256
output_shape = 64
x_in = tf.random.normal(input_shape)
w1_in = tf.random.normal((input_shape[1], hidden_shape))
b1_in = tf.random.normal((hidden_shape,))
w2_in = tf.random.normal((hidden_shape, output_shape))
b2_in = tf.random.normal((output_shape,))
# 执行JIT编译的函数
result = simple_computation(x_in, w1_in, b1_in, w2_in, b2_in)
print("XLA JIT函数执行。")
# print(result.numpy()) # 查看输出
检查XLA生成的HLO通常涉及使用环境变量或TensorFlow日志/分析工具。一种常用方法是使用TensorBoard进行分析,它可以可视化XLA编译步骤和生成的HLO图。另一种方法是在运行脚本之前设置环境变量:
export TF_XLA_FLAGS="--tf_xla_dump_to=/path/to/dump/folder"
# 现在运行您的Python脚本
python your_script.py
此命令指示XLA将其编译过程的各个阶段,包括HLO图(通常是.hlo.dot文件或文本proto格式),转储到指定目录中。然后您可以检查这些文件。
例如,.dot文件可以使用Graphviz工具(dot -Tpng input.dot -o output.png)进行可视化。HLO图可能看起来像这样(简化版):
simple_computation函数的简化HLO图。节点表示操作(dot表示矩阵乘法,add表示加法,maximum表示relu),边表示数据流,参数表示输入。
分析:
dot(用于矩阵乘法)、add、maximum(常用于ReLU)、broadcast和参数。fusion节点。例如,MatMul + BiasAdd + ReLU序列是合并为一个单一优化内核的主要候选,这将在优化后的HLO图中显示为一个节点。这种实践分析演示了如何访问和解释ML JIT编译器生成的中间表示。通过检查TorchScript图或XLA HLO转储,您可以:
分析JIT输出是从事ML框架的性能工程师的一项基本技能。它弥合了高级模型代码与优化后的、硬件特定的执行计划之间的差距,为实现最高性能提供了重要信息。使用这些方法来研究您自己模型的JIT行为,并考察本章中讨论的不同JIT策略的影响。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造