趋近智
尽管统一的低精度计算能带来显著的性能优势,但对于机器学习模型的每个部分来说,这并非总是最佳策略。某些操作可能对量化误差高度敏感,若强制转换为 INT8 或 FP8,可能导致无法接受的精度下降。反之,若在所有地方都保留完整的 FP32 精度,则会抵消潜在的效率提升。混合精度计算提供了一个务实的折衷方案,它有策略地为模型不同部分使用不同的数值格式(例如 FP32、FP16、BF16、FP8、INT8),以平衡精度和性能。
优化这些混合精度模型给编译器带来了独特的挑战。编译器不仅要优化单个操作在其选定精度下的表现,还需有效管理不同精度之间的转换。
混合精度优化的核心思路是,在对精度影响最小的地方积极应用低精度格式,并对敏感操作保留更高精度。常见的需要更高精度的情况通常包括:
编译器的作用是促进这种平衡,要么通过遵守用户指定的精度注解,要么通过自动确定有效的精度配置。
编译器采用多种策略来处理和优化混合精度计算:
在 IR 中的表示: 中间表示需要机制来表示具有不同精度的张量和操作。例如,MLIR 使用其类型系统将特定的浮点或整数类型(例如 tensor<1x256xf16>,tensor<1x1024xi8>)附加到值上。操作本身可能具有指示其操作精度的变体或属性。量化参数(缩放因子,零点)也必须与量化类型关联。
精度分配:
优化精度转换: 数据类型之间的转换(例如,FP32 到 INT8,INT8 到 FP16)会通过量化(Quant)和反量化(DeQuant)操作引入开销。这些通常以逐元素缩放和平移的方式实现。天真的实现会在精度变化的地方插入这些转换,可能造成显著的开销。编译器通过以下方式优化这些转换:
INT8 -> DeQuant -> FP32 Op -> Quant -> INT8 可以被简化,如果 FP32 Op 可以在 INT8 输入上直接实现,并适当处理缩放因子。专用核生成: 编译器的后端必须为处理混合输入/输出或内部计算的操作生成高效代码。这包括:
考虑一个序列 Conv (FP32) -> ReLU (FP32) -> Conv (FP32)。如果我们决定将第二个卷积量化为 INT8,天真的方法会插入转换:
Conv (FP32) -> ReLU (FP32) -> Quant (FP32->INT8) -> DeQuant (INT8->FP32) -> Conv (INT8)
编译器旨在优化这一点。首先,Quant 可能会向后融合到 ReLU 操作中(如果线性,甚至可以融合到前置的 Conv 中)。更重要的是,DeQuant -> Conv(INT8) 模式是融合的首要选择。Conv(INT8) 核可以生成为直接接受 INT8 输入,在乘累加操作期间包含反量化缩放因子(通常累加到 INT32 或 FP32 中),并以累加器精度生成输出。
图示通过融合优化量化和反量化节点。朴素的方法插入显式转换节点,而优化版本则将这些转换融合到相邻的计算操作中。
运行时系统与编译器协同工作。它需要高效管理不同数据类型的内存缓冲区,并处理在不同精度下操作的核之间的执行依赖,可能将它们调度到不同的硬件单元(例如,标准核心与专用矩阵引擎)。
通过结合复杂的编译器分析、转换技术和硬件感知的代码生成,混合精度计算使得开发人员和工具能够取得实际的平衡,从低精度算术中获得显著的性能提升,同时保持对要求较高的机器学习应用所需的精度。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造