PyTorch 和 TensorFlow 等机器学习框架以两种不同的模式运行:即时执行和图执行。即时模式会立即执行操作,当 Python 解释器遇到它们时。这使得调试变得简单,编码直观,但这给优化带来了显著的障碍。编译器无法看到模型的“全貌”;它一次只看到一个操作。为了进行全局优化,例如层融合或内存访问重排序,系统必须捕获整个计算序列到一个被称为计算图的静态结构中。这种将命令式 Python 代码转换为声明式图表示的过程是机器学习编译流程中的第一步。即时执行的瓶颈在标准 Python 执行中,解释器会与框架的调度器交互,对于每一个操作。如果你有一个循环,其中矩阵乘法运行一千次,Python 解释器必须发出调用一千次。考虑以下数学运算:$$z = \text{ReLU}(x \cdot y + b)$$在即时执行环境中,执行流程如下所示:Python 读取 x * y。Python 调用 C++ 内核进行乘法运算。内核返回一个临时张量。Python 读取 ... + b。Python 调用 C++ 内核进行加法运算。内核返回另一个临时张量。Python 调用 ReLU 函数。这引入了“调度开销”。在 Python 和底层 C++ 运行时之间切换所花费的时间有时会超过实际计算所花费的时间,特别是对于小型操作符。为了消除这种开销并实现操作符融合,我们必须将这些操作捕获到中间表示 (IR) 中,这种 IR 将模型逻辑与 Python 解释器分离。图捕获机制从动态框架中捕获图主要有两种方式:追踪和脚本化。追踪追踪是在操作执行时记录它们的过程。要追踪模型,你需要将一个虚拟输入(通常称为示例输入)通过网络。框架不仅仅是计算结果;它会记录输入张量上发生的每一个数学操作。该机制通过使用代理对象来工作。当框架看到一个代理张量进入算术操作时,它会在图中创建一个表示该操作的节点,而不是立即执行它(或者在执行它的同时)。digraph G { rankdir=TB; node [shape=box, style=filled, fontname="Arial", fontsize=12, color="#dee2e6"]; edge [fontname="Arial", fontsize=10, color="#868e96"]; subgraph cluster_0 { label = "Python 环境"; style = filled; color = "#f8f9fa"; input [label="虚拟输入\n(张量)", fillcolor="#a5d8ff"]; model [label="模型定义\n(Python 代码)", fillcolor="#e9ecef"]; tracer [label="追踪器 / JIT 编译器", fillcolor="#b197fc"]; } subgraph cluster_1 { label = "捕获的输出"; style = filled; color = "#f8f9fa"; graph_ir [label="静态计算\n图 (IR)", fillcolor="#69db7c"]; } input -> model [label="输入到"]; model -> tracer [label="执行流"]; tracer -> graph_ir [label="记录操作"]; tracer -> input [label="跟踪数据依赖", dir=back, style=dashed]; }追踪过程中数据流的示意图,Python 执行被记录到静态图中。追踪很高效,因为它不需要解析 Python 源代码。它只是观察发生了什么。如果你的模型调用了一个最终执行 PyTorch 张量操作的第三方库,追踪器会捕获它,只要数据流与输入张量保持关联。脚本化脚本化直接分析 Python 源代码(通常检查抽象语法树或 AST)。它将 Python 控制结构,如 if、for 和 while,直接转换为其对应的图表示。虽然脚本化保留了控制流逻辑,但它通常比追踪更脆弱,因为它要求 Python 代码严格遵守编译器能理解的语言子集。追踪机制的实践我们来看看追踪如何处理一个简单的线性变换后接激活操作。这是密集层中常见的模式。import torch import torch.nn as nn class SimpleLayer(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) self.relu = nn.ReLU() def forward(self, x): return self.relu(self.linear(x)) # Instantiate model and dummy input model = SimpleLayer() dummy_input = torch.randn(1, 10) # Trace the model traced_graph = torch.jit.trace(model, dummy_input) # The traced_graph is now independent of the Python class print(traced_graph.graph)当调用 torch.jit.trace 时,编译器运行 forward 方法。它观察到矩阵乘法(来自 linear)、加法(偏置)和 ReLU 操作。它输出一个在中间表示 (IR) 中大致如下所示的图:%1 = matmul(%x, %W)%2 = add(%1, %b)%3 = relu(%2)return %3这个 IR 不包含 Python 开销。它是数据依赖的纯粹描述。这个图现在可以被序列化,加载到 C++ 运行时中,或者传递给像 TVM 或 XLA 这样的低级编译器进行硬件映射。控制流的处理追踪的一个显著局限是它无法高效捕获动态控制流。因为追踪记录的是使用虚拟输入进行特定运行时所采取的路径,它有效地“固化”了逻辑决策。考虑一个带有条件语句的函数:def risky_function(x): if x.sum() > 0: return x * 2 else: return x - 1如果你使用 x.sum() > 0 的输入来追踪这个函数,追踪器只会记录乘法路径。结果图将如下所示:$$y = x \cdot 2$$if 语句被完全移除。如果你随后用 x.sum() < 0 的输入运行这个编译后的图,它仍将执行乘法路径,导致不正确的结果。对于具有静态结构的模型(例如标准 ResNet 或 Transformer,其结构不根据输入数据改变),追踪非常高效。对于需要动态逻辑的模型(例如递归网络或具有可变边界的循环),脚本化或专门的控制流操作符是必需的。图捕获的优势一旦图被捕获,编译器就能获得程序的全局视图。这使得在即时模式下无法实现的一些优化成为可能:死代码消除: 如果图的一部分对输出没有贡献,它可以被剪除。公共子表达式消除: 多次出现的冗余计算可以只计算一次并重复使用。代数简化: 数学操作可以被简化(例如,组合转置)。操作符融合: 这是最重要的优化。编译器可以将我们之前例子中的乘法、加法和 ReLU 合并为一个单独的内核启动,大幅减少内存带宽使用。下面的图表说明了通过捕获和融合操作实现的内核启动次数减少。{"layout": {"title": {"text": "图捕获对内核启动次数的影响", "font": {"family": "Arial", "size": 16, "color": "#495057"}}, "xaxis": {"title": {"text": "执行模式", "font": {"family": "Arial", "size": 12, "color": "#868e96"}}, "showgrid": false}, "yaxis": {"title": {"text": "GPU 内核启动次数", "font": {"family": "Arial", "size": 12, "color": "#868e96"}}, "showgrid": true, "gridcolor": "#e9ecef"}, "plot_bgcolor": "white", "margin": {"t": 50, "l": 50, "r": 30, "b": 50}, "height": 350, "width": 600, "showlegend": false}, "data": [{"type": "bar", "x": ["即时执行", "图捕获(融合后)"], "y": [15, 4], "marker": {"color": ["#adb5bd", "#339af0"]}, "text": ["高调度开销", "已优化"], "textposition": "auto"}]}标准即时执行与对一系列张量操作捕获并融合后的图,在 GPU 内核启动次数上的比较。从捕获到优化图捕获是高级框架代码和低级硬件实现之间的桥梁。通过成功追踪模型,你可以将一个依赖 Python 的程序转换为可移植、可优化的中间表示。这种表示作为编译堆栈后续阶段的输入,在这些阶段,我们应用图级变换和循环级优化来最大限度地提高硬件效率。