数学等价性允许编译器重写计算图,使其成为更高效的形式,而不会改变最终结果。代数简化是一种优化方法,旨在减少模型本身的算术复杂度。此过程运用初等代数和线性代数的公理,以找出多余的操作,用开销更小的指令替换开销大的指令,并重新安排计算以提升性能。前端框架常因为使用高级抽象而生成次优的图。用户编写的代码逻辑上可能合理,却引入了不必要的转置,或等同于将张量乘以单位矩阵。代数简化阶段的目标是在张量调度这项繁重工作开始之前,将图标准化并消除这些低效之处。强度消减强度消减的做法是用功能相同但硬件开销更小的指令替代计算开销大的操作。在深度学习加速器的背景下,特定算术逻辑单元(ALU)之间的开销差异可能很大。超越函数(如 pow、exp、log)和除法操作通常比简单的乘法或加法消耗更多的时钟周期和芯片面积。以幂运算为例,它常用于方差计算或欧几里得距离度量。一个通用的幂运算内核处理任意浮点指数,需要复杂的逻辑。然而,如果指数是一个小整数,编译器可以重写该节点:$$y = x^2 \rightarrow y = x \cdot x$$$$y = x^3 \rightarrow y = x \cdot x \cdot x$$这种转换将对特殊功能单元的调用替换为简单的乘法指令,这些指令通常可向量化,并在GPU上具有更高的吞吐量。类似地,除以常数 $C$ 可以转换为乘以倒数 $1/C$。由于在大多数指令集架构(ISA)上,乘法通常比除法快,这会带来显著的加速。$$y = \frac{x}{C} \rightarrow y = x \cdot (C^{-1})$$编译器在编译阶段预先计算 $C^{-1}$(常数折叠),因此运行时只执行乘法操作。恒等和冗余消除图中常包含一些操作,它们没有计算目的,是自动微分或模块化模型设计的产物。消除这些节点能减少内核启动开销和内存带宽占用。常见的恒等转换包括:加法恒等式: $x + 0 \rightarrow x$乘法恒等式: $x \cdot 1 \rightarrow x$幂等性: $\text{max}(x, x) \rightarrow x$在张量代数中,转置抵消是常见的优化目标。转置操作本质上是交换张量的步长。连续两次转置会撤销步长交换,使张量恢复到其原始布局。如果一个图中包含两个互为逆操作的连续转置,那么这两个节点可以完全移除。$$ (A^T)^T \rightarrow A $$这种逻辑也适用于切片和拼接。如果一个张量被切片,然后立即以相同顺序拼接回来,那么这个序列是一个空操作。下图展示了一个包含冗余线性代数操作的图的结构简化。digraph G { rankdir=TB; node [shape=box, style="filled", fontname="Arial", fontsize=10, color="#dee2e6"]; edge [color="#adb5bd"]; subgraph cluster_0 { label="原始图"; style="dashed"; color="#adb5bd"; fontcolor="#868e96"; node_a [label="输入张量A", fillcolor="#eebefa"]; node_t1 [label="转置 (0, 1)", fillcolor="#d0bfff"]; node_t2 [label="转置 (0, 1)", fillcolor="#d0bfff"]; node_add [label="加0", fillcolor="#ffc9c9"]; node_out [label="输出", fillcolor="#eebefa"]; node_a -> node_t1; node_t1 -> node_t2; node_t2 -> node_add; node_add -> node_out; } subgraph cluster_1 { label="优化后的图"; style="dashed"; color="#adb5bd"; fontcolor="#868e96"; node_a_opt [label="输入张量A", fillcolor="#eebefa"]; node_out_opt [label="输出", fillcolor="#eebefa"]; node_a_opt -> node_out_opt [label="零拷贝直通"]; } }冗余子图的优化。转置对相互抵消,加零是恒等操作,从而允许输入直接流向输出,无需计算。GEMM中的结合律重排最具影响力的代数简化之一是矩阵乘法(GEMM)的重新排序。矩阵乘法符合结合律,即 $(AB)C = A(BC)$。然而,这两种求值顺序的计算开销很少相同。开销很大程度上取决于所涉及矩阵的维度。我们将计算复杂度定义为将 $M \times K$ 大小的矩阵乘以 $K \times N$ 大小的矩阵,大致为 $O(M \cdot N \cdot K)$。考虑一个包含三个矩阵的链,其维度如下:$A: 10 \times 100$$B: 100 \times 10$$C: 10 \times 100$顺序 1: $(A \times B) \times C$计算 $T = A \times B$。结果 $T$ 是 $10 \times 10$。开销:$10 \cdot 10 \cdot 100 = 10,000$ FLOPs。计算 $R = T \times C$。结果 $R$ 是 $10 \times 100$。开销:$10 \cdot 100 \cdot 10 = 10,000$ FLOPs。总计: 20,000 FLOPs。顺序 2: $A \times (B \times C)$计算 $T = B \times C$。结果 $T$ 是 $100 \times 100$。开销:$100 \cdot 100 \cdot 10 = 100,000$ FLOPs。计算 $R = A \times T$。结果 $R$ 是 $10 \times 100$。开销:$10 \cdot 100 \cdot 100 = 100,000$ FLOPs。总计: 200,000 FLOPs。在这个特定场景下,第一种顺序比第二种快一个数量级。编译器可以分析IR中张量的静态形状,并引入一个重排阶段来减少总浮点运算(FLOPs)。{ "layout": { "title": "GEMM重排的计算开销", "xaxis": { "title": "求值顺序" }, "yaxis": { "title": "FLOPs(越低越好)" }, "barmode": "group", "width": 500, "height": 300, "margin": {"l": 50, "r": 50, "t": 50, "b": 50} }, "data": [ { "x": ["(A x B) x C", "A x (B x C)"], "y": [20000, 200000], "type": "bar", "marker": { "color": ["#37b24d", "#fa5252"] }, "text": ["20k FLOPs", "200k FLOPs"], "textposition": "auto" } ] }矩阵乘法不同结合分组所需浮点运算的比较。优化后的顺序显著降低了计算负担。广播和布局简化现代深度学习框架非常依赖广播语义(例如 NumPy 风格的广播)。尽管对开发者来说很方便,广播可能会掩盖底层的线性代数。简化阶段分析广播模式以发现缩减机会。例如,归一化层中的常见模式是张量的规约,然后广播回原始形状进行除法。如果后续操作有效遮掩或切分此张量,编译器可能证明完整的广播是不必要的。此外,代数特性允许布局转换的传播。像 ReLU 或 Tanh 这样的操作是逐元素的,并且与数据布局排列可交换。这使得编译器可以将转置操作推过激活函数,以便将它们与其他对布局敏感的操作(如卷积)分组,这可能实现前面提到的转置的融合或消除。$$ \text{ReLU}(\text{转置}(A)) \Leftrightarrow \text{转置}(\text{ReLU}(A)) $$浮点精度与安全性需要指出的是,代数简化在实数 $\mathbb{R}$ 领域在数学上是严密的,但在浮点运算领域并非总是精确的。由于舍入误差,浮点加法并非严格符合结合律。$$ (a + b) + c \neq a + (b + c) $$当对不同数量级的数字求和时,这种差异会变得明显。深度学习编译器常提供“快速数学”标志。启用后,编译器将浮点数视为实数,允许积极的重排和简化。当精度要求高时(例如在科学计算或高精度训练的梯度累积中),这些代数阶段必须仅限于值安全的转换。在推理场景下,模型通常对微小噪声具有鲁棒性,编译器通常默认采用积极的代数简化以最大化吞吐量。编译器阶段的实现者必须明确分类重写规则,依据它们是否保持位精度或引入数值偏差。