趋近智
冗余是机器学习模型中常见的情况。尽管数据科学家通常会避免重复编写相同的 Python 代码行,但将高级操作符展开为低级原语时,经常会引入重复的计算。公共子表达式消除 (CSE) 是一项基本的优化过程,它能发现这些冗余,并重构图,使计算结果只进行一次并重用。
在通用编译器的背景下,CSE 会找出在多个位置出现的相同表达式,如 a+b。如果在这些出现之间 a 和 b 都没有变化,编译器会计算一次和,将其存入一个临时变量,并用该变量替换后续的实例。
对于处理计算图 (DAG) 的机器学习编译器,这个过程是针对节点和边操作的,而不是针对源代码行。目的是找出产生完全相同张量结果的子图,移除重复的子图能有效减少前向传播所需的算术操作 (FLOPs)。
编译器不能仅仅因为两个操作使用了相同的操作符类型就认为它们是相同的。为了安全地执行 CSE,编译器必须建立结构等价性。如果两个节点满足以下三个条件,则它们在结构上是等价的:
MatMul 或 ReLU。设想一个模型计算两个分支的情况。两个分支都从归一化输入数据开始。在原始图表示中,这表现为两个不同的 Batch_Norm 节点。由于它们共享相同的输入张量和参数,编译器会将它们合并为一个节点。
识别计算图中的冗余节点。第一个子图中的红色节点表示一个重复操作,通过重用蓝色节点的输出,在第二个子图中被消除。
实现 CSE 需要一种高效的方法来检测重复项,而无需将每个节点与其他所有节点进行比较,否则会导致 O(N2) 的复杂度。相反,编译器通常采用哈希方法来实现接近线性的时间复杂度。
该过程按拓扑顺序遍历图。对于每个访问的节点,编译器根据节点的操作类型、输入索引和属性生成一个哈希签名。这个签名作为所执行计算的唯一标识。
H节点=哈希(操作码,[H输入1,H输入2,...],属性)
编译器维护一个哈希映射(或字典),其中的键是这些签名,值是生成它们的节点的引用。
尽管用户很少会明确编写冗余代码,但重复的子表达式会很自然地产生自更高层次的抽象。
多头注意力 在 Transformer 架构中,查询 (Query)、键 (Key) 和值 (Value) 的投影通常涉及切片一个大张量或应用线性变换。如果特定实现分离了这些操作,但使用相同的权重矩阵进行初始化或重塑逻辑,CSE 可以合并这些设置步骤。
梯度计算 冗余最主要的来源是反向传播。自动微分引擎通过应用链式法则生成梯度。这个过程经常生成镜像前向传播部分或在梯度图的不同分支中重复计算的子图。像 XLA (Accelerated Linear Algebra) 或 TVM 这样的编译器大量依赖 CSE 来清理自动微分系统生成的冗余图。
超参数搜索与架构搜索 使用神经网络架构搜索 (NAS) 时,生成的候选图通常包含重叠的分支。CSE 有助于找出共享组件,使系统能够验证模型的某些部分是否已计算过。
尽管 CSE 减少了操作数量,但它并非总是对性能绝对有利。权衡之处在于内存使用。
当一个子表达式被消除时,剩余节点的结果必须在内存中保持存活,直到最后一个消费者读取它。如果原始的两次出现位置在执行计划中相距较远,优化后的图需要更长时间地存储张量,从而增加峰值内存占用。
设想一个张量 T,在网络的开始处使用,并在网络的末尾再次使用。
高级编译器使用成本模型来决定是否应用 CSE。如果重新计算一个值很廉价(如标量加法),但存储它很昂贵(高寄存器压力或阻塞内存分配),编译器可能会选择执行重实例化,特意跳过 CSE 以节省内存。
为了理解编译器如何看待这种情况,我们可以查看图的文本表示 (IR)。假设我们有一个接受张量 %x 的函数。
原始 IR:
%1 = multiply(%x, 2.0)
%2 = add(%1, 5.0)
%3 = multiply(%x, 2.0) // 重复计算
%4 = subtract(%3, 1.0)
%5 = add(%2, %4)
编译器会追踪 %1 的定义。当它遇到 %3 时,它会发现操作码 multiply、输入 %x 和常数 2.0 与 %1 的定义匹配。
优化后的 IR:
%1 = multiply(%x, 2.0)
%2 = add(%1, 5.0)
// %3 被 %1 替换
%4 = subtract(%1, 1.0)
%5 = add(%2, %4)
操作数量从 5 减少到 4。在拥有数十亿参数的大规模模型中,将这些在重复块中的微小节省累加起来,会带来可衡量的延迟改进。通过系统地找出这些模式,编译器确保硬件只将计算周期用于独特且必需的数学运算。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造