趋近智
内存带宽常是深度学习 (deep learning)推理 (inference)和训练中的主要性能瓶颈。现代GPU和TPU虽有巨量算力 (compute),但其性能常常受限于高带宽内存 (HBM) 与计算单元之间的数据传输速度。算子融合是解决此瓶颈最有效的图级别优化手段。
为理解融合的必要性,考量神经网络 (neural network)中常见的简单操作序列:一个矩阵乘法,后跟一个偏置 (bias)加法和一个ReLU激活函数 (activation function)。
在PyTorch或TensorFlow等标准深度学习 (deep learning)框架中,执行以即时模式或通过图进行,图中每个操作都被视为独立的内核启动。未经优化,硬件会分三个不同的步骤执行此序列:
这种方法效率不高,因为中间张量和是瞬态的。它们仅用于后续步骤,却需要付出高昂的片外内存访问代价。对于加法和ReLU步骤,算术强度(浮点运算 (FLOPs) 与内存访问字节数之比)较低,导致昂贵的计算核心在等待内存时闲置。
算子融合通过将这些相邻节点合并为一个单独的内核来解决此问题。融合后的内核不会将中间结果写入全局内存,而是将数据保存在更快的片上内存(寄存器或L1缓存)中。处理器一次性加载输入,执行矩阵乘法,添加偏置,应用ReLU,并仅写入最终结果。
未融合与融合操作中数据移动的比较。融合方法消除了全局内存的往返访问。
编译器根据所涉及操作的迭代模式对融合机会进行分类。并非所有操作都能轻松融合;这取决于数据如何从输入映射到输出。
这是最简单、最常见的融合形式。它适用于操作将一个输入元素映射到一个输出元素(单射映射)的情况。例子包括张量加法、减法、乘法、ReLU、Sigmoid和Tanh。
由于这些操作的循环结构相同(迭代相同的张量形状),编译器会将循环体合并。
未融合的循环嵌套(伪代码):
// 内核 1:加法
for (int i = 0; i < N; i++) {
temp[i] = A[i] + B[i];
}
// 内核 2:ReLU
for (int i = 0; i < N; i++) {
C[i] = max(0, temp[i]);
}
融合后的循环嵌套:
// 融合内核
for (int i = 0; i < N; i++) {
float t = A[i] + B[i]; // 保存在寄存器中
C[i] = max(0, t);
}
规约涉及降低张量秩的操作,例如sum、max或mean。将逐元素操作融合到后续规约中通常很直接。例如,计算平方和允许在同一循环内,平方操作紧接在累加之前发生。
然而,将规约融合到后续的逐元素操作中更为复杂。如果计算一个sum后跟一个除法(如Softmax),必须先完整计算总和,然后才能对任何元素进行除法。这引入了屏障同步要求,通常需要多趟处理或专用硬件内联函数(如GPU上的warp shuffle)才能在一个内核中高效处理。
将卷积或矩阵乘法 (GEMM) 等复杂算子与后续的逐元素操作(偏置 (bias)加法、激活)融合,具有显著影响。由于GEMM和卷积是计算密集型操作,在计算管道的末尾添加少量标量操作(如加法或钳位 (ReLU))在执行时间上几乎不增加开销。算术单元处理激活,而内存控制器正忙于处理矩阵结果的写回。
编译器不能简单地合并所有相邻节点。它必须遵循数据依赖性,并确保转换保留程序的逻辑。优化器构建一个有向无环图 (DAG),并分析边以确定融合的有效性。
一个主要限制是循环检查。如果节点A输出到节点B,并且我们想融合它们,我们必须确保融合它们不会创建循环或使其他路径失效。例如,如果节点A产生一个同时被节点B和节点C使用的输出,将A融合到B中通常需要为B复制A的计算,或者将A仍然写入内存供C使用。
如果编译器将A融合到B中,但C仍然需要A的输出,编译器必须验证在B内部重新计算A是否比从内存读取更划算。这种策略被称为重物化,它以额外的计算量换取更低的内存带宽。
机器学习 (machine learning)编译器通常采用基于规则或基于成本的方法来识别候选:
Conv2D -> Add -> ReLU。下图显示了将融合应用于标准操作块时潜在的性能提升和内存流量减少。
性能提升通常源于全局内存流量的减少,而非操作数量的减少。
在TVM或MLIR等现代编译器栈中,融合被实现为一个特定的“Pass”(编译阶段)。该阶段遍历中间表示 (IR)。当它识别出可融合的模式时,它将节点序列替换为单个FusedOp节点。这个新节点包含一个复合函数体。
在编译流程的后期,代码生成阶段,这个复合体生成一个紧密的循环嵌套。后端负责中间变量的寄存器分配,确保它们不会触及主内存分配系统。
通过了解融合,您可以编写更易于编译器优化的模型。例如,使用自定义的复合损失函数 (loss function)通常允许编译器生成单个优化内核,而将相同的逻辑写成分散的Python操作可能会让编译器更难识别融合机会。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造