趋近智
PyTorch 和 TensorFlow 等机器学习框架以两种不同的模式运行:即时执行和图执行。即时模式会立即执行操作,当 Python 解释器遇到它们时。这使得调试变得简单,编码直观,但这给优化带来了显著的障碍。编译器无法看到模型的“全貌”;它一次只看到一个操作。
为了进行全局优化,例如层融合或内存访问重排序,系统必须捕获整个计算序列到一个被称为计算图的静态结构中。这种将命令式 Python 代码转换为声明式图表示的过程是机器学习编译流程中的第一步。
在标准 Python 执行中,解释器会与框架的调度器交互,对于每一个操作。如果你有一个循环,其中矩阵乘法运行一千次,Python 解释器必须发出调用一千次。
考虑以下数学运算:
z=ReLU(x⋅y+b)
在即时执行环境中,执行流程如下所示:
x * y。... + b。这引入了“调度开销”。在 Python 和底层 C++ 运行时之间切换所花费的时间有时会超过实际计算所花费的时间,特别是对于小型操作符。为了消除这种开销并实现操作符融合,我们必须将这些操作捕获到中间表示 (IR) 中,这种 IR 将模型逻辑与 Python 解释器分离。
从动态框架中捕获图主要有两种方式:追踪和脚本化。
追踪是在操作执行时记录它们的过程。要追踪模型,你需要将一个虚拟输入(通常称为示例输入)通过网络。框架不仅仅是计算结果;它会记录输入张量上发生的每一个数学操作。
该机制通过使用代理对象来工作。当框架看到一个代理张量进入算术操作时,它会在图中创建一个表示该操作的节点,而不是立即执行它(或者在执行它的同时)。
追踪过程中数据流的示意图,Python 执行被记录到静态图中。
追踪很高效,因为它不需要解析 Python 源代码。它只是观察发生了什么。如果你的模型调用了一个最终执行 PyTorch 张量操作的第三方库,追踪器会捕获它,只要数据流与输入张量保持关联。
脚本化直接分析 Python 源代码(通常检查抽象语法树或 AST)。它将 Python 控制结构,如 if、for 和 while,直接转换为其对应的图表示。虽然脚本化保留了控制流逻辑,但它通常比追踪更脆弱,因为它要求 Python 代码严格遵守编译器能理解的语言子集。
我们来看看追踪如何处理一个简单的线性变换后接激活操作。这是密集层中常见的模式。
import torch
import torch.nn as nn
class SimpleLayer(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.linear(x))
# Instantiate model and dummy input
model = SimpleLayer()
dummy_input = torch.randn(1, 10)
# Trace the model
traced_graph = torch.jit.trace(model, dummy_input)
# The traced_graph is now independent of the Python class
print(traced_graph.graph)
当调用 torch.jit.trace 时,编译器运行 forward 方法。它观察到矩阵乘法(来自 linear)、加法(偏置)和 ReLU 操作。它输出一个在中间表示 (IR) 中大致如下所示的图:
%1 = matmul(%x, %W)%2 = add(%1, %b)%3 = relu(%2)return %3这个 IR 不包含 Python 开销。它是数据依赖的纯粹描述。这个图现在可以被序列化,加载到 C++ 运行时中,或者传递给像 TVM 或 XLA 这样的低级编译器进行硬件映射。
追踪的一个显著局限是它无法高效捕获动态控制流。因为追踪记录的是使用虚拟输入进行特定运行时所采取的路径,它有效地“固化”了逻辑决策。
考虑一个带有条件语句的函数:
def risky_function(x):
if x.sum() > 0:
return x * 2
else:
return x - 1
如果你使用 x.sum() > 0 的输入来追踪这个函数,追踪器只会记录乘法路径。结果图将如下所示:
y=x⋅2
if 语句被完全移除。如果你随后用 x.sum() < 0 的输入运行这个编译后的图,它仍将执行乘法路径,导致不正确的结果。
对于具有静态结构的模型(例如标准 ResNet 或 Transformer,其结构不根据输入数据改变),追踪非常高效。对于需要动态逻辑的模型(例如递归网络或具有可变边界的循环),脚本化或专门的控制流操作符是必需的。
一旦图被捕获,编译器就能获得程序的全局视图。这使得在即时模式下无法实现的一些优化成为可能:
下面的图表说明了通过捕获和融合操作实现的内核启动次数减少。
标准即时执行与对一系列张量操作捕获并融合后的图,在 GPU 内核启动次数上的比较。
图捕获是高级框架代码和低级硬件实现之间的桥梁。通过成功追踪模型,你可以将一个依赖 Python 的程序转换为可移植、可优化的中间表示。这种表示作为编译堆栈后续阶段的输入,在这些阶段,我们应用图级变换和循环级优化来最大限度地提高硬件效率。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造