代数简化针对计算图的数学结构。这种优化技术与其他技术相辅相成,例如运算符合并(它侧重于合并操作以减少开销并提高局部性)以及布局转换(它优化数据移动)。在进阶层面,代数简化远超简单的常量折叠(例如 $2 + 2 = 4$)或恒等消除(例如 $x * 1 = x$)。相反,它应用源自线性代数、微积分以及机器学习运算符特定属性的复杂数学规则,以完全简化或消除计算。高级代数简化的主要目标包括:降低计算成本: 直接消除冗余计算,或用更经济的等效操作替换高成本操作。简化图结构: 减少节点和边的数量,这可以使图更容易被后续优化过程(例如合并或调度)分析和优化。促成进一步优化: 简化可能为运算符合并提供新的机会,或呈现适合专用核实现的模式。常见简化模式进阶机器学习编译器采用模式匹配引擎(通常是前面提到的图重写系统的一部分)来识别与已知代数恒等式对应的子图。以下是机器学习图中经常使用的简化规则示例:线性代数恒等式:转置: 消除双重转置: $$(A^T)^T \rightarrow A$$。在可能的情况下,将转置与矩阵乘法结合,尽管这通常与布局偏好有很大关联: $$(A B)^T = B^T A^T$$。单位元和零元: 简化涉及单位矩阵 ($I$) 或零张量 ($Z$) 的操作: $$A + Z \rightarrow A$$, $$A \times I \rightarrow A$$, $$A \times Z \rightarrow Z$$。请注意,构造单位张量或零张量本身可能产生开销,因此消除它们是有益的。结合律/分配律: 重新关联像矩阵乘法这样的操作($$(A B) C \rightarrow A (B C)$$)有时可以提高性能或促成合并,但由于潜在的浮点精度差异需要谨慎。应用分配律($$A (B + C) \rightarrow A B + A C$$)则不那么常见,因为它通常会增加操作计数,但如果 $$A B$$ 或 $$A C$$ 可以在其他地方合并或简化,它可能有用。逐元素操作与广播:简化逐元素操作链:exp(log(x)) $\rightarrow$ x(在有效域内),scale(scale(x, a), b) $\rightarrow$ scale(x, a*b)。常量折叠:x + Constant(a) + Constant(b) $\rightarrow$ x + Constant(a+b)。简化广播操作:一个 broadcast_to(x, shape) 后跟一个与已具有 shape 的另一个张量 y 进行的逐元素操作,可能允许将广播合并到逐元素核中。重排可交换的逐元素操作以组合常量或促成合并。归约操作:组合归约:reduce_sum(reduce_sum(x, axis=0), axis=1) 可能根据轴的不同合并为一个单一归约。与逐元素操作的分配属性:reduce_sum(x + y) $\rightarrow$ reduce_sum(x) + reduce_sum(y)(如果形状和归约轴对齐)。归一化和激活层:折叠操作:一个 batch_norm 层紧随其后的 scale 和 shift(仿射变换)通常可以在数学上折叠成一个具有调整参数的等效 batch_norm。简化激活:relu(Constant(-5)) $\rightarrow$ Constant(0)。sigmoid(very_large_positive_constant) $\rightarrow$ Constant(1.0)。恒等序列:dropout(x, rate=0) $\rightarrow$ x。通过图重写实现这些简化通常作为重写规则实现在编译器的图优化框架内。每条规则定义:模式: 要匹配的子图结构(例如,输入是另一个 Transpose 节点的 Transpose 节点)。约束: 必须满足的条件(例如,数据类型必须匹配,张量形状必须兼容)。替换: 用于替换匹配模式的更简单子图结构。编译器迭代地应用这些规则,直到无法找到更多简化(即达到不动点)。应用顺序有时很重要,特别是当某些简化可能启用或禁用其他简化时。考虑将缩放和偏置操作折叠到前置卷积中的常见模式:digraph Before { rankdir=LR; graph [bgcolor="transparent"]; node [shape=box, style=filled, color="#a5d8ff", fontname="Arial"]; edge [color="#495057", fontname="Arial"]; subgraph cluster_0 { label = "简化前"; bgcolor="#e9ecef"; style=rounded; X -> Conv [label=" 输入"]; Conv -> AddBias [label=" 特征"]; AddBias -> Scale [label=" 偏置特征"]; Scale -> Output [label=" 缩放特征"]; } X [label="数据 (N,C,H,W)"]; Conv [label="Conv2D\n权重 W\n偏置 B_conv"]; AddBias [label="加法\n偏置 b_add"]; Scale [label="乘法\n缩放 s"]; Output [label="结果"]; }一个涉及卷积、添加偏置张量和缩放的序列。代数简化过程可以识别 AddBias 和 Scale 操作构成一个仿射变换,可以直接合并到 Conv2D 的参数中。该规则将匹配 Scale(Add(Conv(x, W, B_conv), b_add), s),并用一个使用修改后的权重 ($W'$) 和偏置 ($B'_{conv}$) 的单一 Conv2D 操作来替换它。变换结果为: $s \times (Conv(x, W, B_{conv}) + b_{add}) = Conv(x, s \times W, s \times B_{conv} + s \times b_{add})$因此,新参数是 $W' = s \times W$ 和 $B'{conv} = s \times B{conv} + s \times b_{add}$。digraph After { rankdir=LR; graph [bgcolor="transparent"]; node [shape=box, style=filled, color="#96f2d7", fontname="Arial"]; edge [color="#495057", fontname="Arial"]; subgraph cluster_1 { label = "简化后"; bgcolor="#e9ecef"; style=rounded; X -> Conv_prime [label=" 输入"]; Conv_prime -> Output [label=" 结果"]; } X [label="数据 (N,C,H,W)"]; Conv_prime [label="Conv2D\n权重 W' = s*W\n偏置 B'_conv = s*B_conv + s*b_add"]; Output [label="结果"]; }将加偏置和缩放操作折叠到卷积中后的图。这种简化减少了两个图操作,并可能避免中间张量的具体化,从而降低内存使用和减少核启动次数,或为修改后的卷积实现更高效的合并核。考量与挑战尽管功能强大,代数简化需要谨慎实现:浮点语义: 由于浮点运算的非结合性,激进地重排或简化浮点操作可能会改变数值输出。编译器通常有标志来控制是否允许“不安全”(可能改变精度)的简化。数值稳定性: 某些数学上等效的形式可能比其他形式的数值稳定性差(例如,易受灾难性抵消影响)。重写规则通常应偏好稳定的形式。复杂性: 实现和验证大量复杂代数规则具有挑战性。确保在不同运算符组合和数据类型上的正确性需要广泛的测试。与其他过程的交互: 代数简化的有效性可能取决于前面的过程(例如,常量折叠),并影响随后的过程(例如,合并)。整体优化流水线设计必须考虑这些交互。通过应用这些高级代数技术,编译器可以显著精简计算图,为更有效的低级代码生成铺平道路,并为机器学习工作负载实现更高的性能。