Alright, let's put theory into practice. We've discussed how quantized operations are represented within compiler IRs. Now we'll focus on the critical step of lowering: transforming these high-level quantized operations into sequences of lower-level, often standard integer, arithmetic operations that explicitly handle scaling, zero-point adjustments, and type conversions. This process is fundamental for generating executable code, especially for hardware that lacks direct support for arbitrarily scaled quantized types but possesses efficient integer arithmetic units.
Imagine we have a high-level quantized operation in our IR, perhaps representing a 2D convolution or matrix multiplication. It looks something like this (using a simplified, MLIR-inspired notation):
```mlir
// High-Level Representation
%input_q  = "quant.cast"(%input_fp32)  : tensor<1xHxWxCinxf32> -> tensor<1xHxWxCin x !quant.uniform<i8:...>>
%weight_q = "quant.cast"(%weight_fp32) : tensor<CoutxKhxKwxCinxf32> -> tensor<CoutxKhxKwxCin x !quant.uniform<i8:...>>
%bias_q   = "quant.cast"(%bias_fp32)   : tensor<Coutxf32> -> tensor<Cout x !quant.uniform<i32:...>> // Bias often INT32

// Represents the entire quantized convolution including output scaling
%output_q = "quant.conv2d"(%input_q, %weight_q, %bias_q) {
  strides = [...], padding = [...],
  output_scale = ..., output_zero_point = ...
} : (..., ..., ...) -> tensor<1xHoxWoxCout x !quant.uniform<i8:...>>
```
This `quant.conv2d` operation encapsulates the core computation along with the necessary requantization logic to produce the final INT8 output. Our goal is to lower this into operations available in standard dialects (like `arith` for arithmetic, `memref` for memory access in MLIR) or into target-specific intrinsics.
Recall the affine quantization formula: $r = S \times (q - Z)$, where $r$ is the real value, $q$ is the quantized integer value, $S$ is the scale, and $Z$ is the zero point.
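As a quick illustration, here is a minimal Python sketch of this mapping (the helper names and example values are illustrative, not taken from any particular framework):

```python
import numpy as np

def quantize(r, scale, zero_point, qmin=-128, qmax=127):
    """q = round(r / S) + Z, clamped to the INT8 range."""
    q = np.round(r / scale) + zero_point
    return int(np.clip(q, qmin, qmax))

def dequantize(q, scale, zero_point):
    """r = S * (q - Z)."""
    return scale * (q - zero_point)

# Example: with S = 0.01 and Z = -28, INT8 codes cover roughly [-1.0, 1.55]
S, Z = 0.01, -28
q = quantize(0.42, S, Z)     # -> 14
r = dequantize(q, S, Z)      # -> 0.42 (up to rounding error)
```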
For a convolution (or matrix multiplication), the core computation involves multiply-accumulate operations. Let's consider a single output element calculation, which is a sum of products:

$$\text{Output}_{real} = \sum(\text{Input}_{real} \times \text{Weight}_{real}) + \text{Bias}_{real}$$
Substituting the quantization formula:
$$S_{out}(\text{Output}_q - Z_{out}) \approx \sum\left[S_{in}(\text{Input}_q - Z_{in}) \times S_w(\text{Weight}_q - Z_w)\right] + S_{bias}(\text{Bias}_q - Z_{bias})$$

Our target is to compute $\text{Output}_q$ using integer arithmetic. Rearranging the equation involves expanding the products, and the full expansion can get complex:

$$S_{out}(\text{Output}_q - Z_{out}) \approx S_{in}S_w\sum(\text{Input}_q - Z_{in})(\text{Weight}_q - Z_w) + S_{bias}(\text{Bias}_q - Z_{bias})$$

$$S_{out}(\text{Output}_q - Z_{out}) \approx S_{in}S_w\sum\left[\text{Input}_q\,\text{Weight}_q - \text{Input}_q Z_w - Z_{in}\text{Weight}_q + Z_{in}Z_w\right] + S_{bias}(\text{Bias}_q - Z_{bias})$$

The compiler's task during lowering is to generate efficient integer code that computes the right-hand side and then solves for $\text{Output}_q$.
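Notice that the sum over the expanded products can be regrouped so that the zero-point terms move outside the inner loop; the post-accumulation correction used in the lowering below relies on exactly this. A small NumPy check with made-up values confirms the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 64                                               # reduction size (e.g., Cin * Kh * Kw)
q_in = rng.integers(-128, 128, K, dtype=np.int64)    # quantized inputs
q_w  = rng.integers(-128, 128, K, dtype=np.int64)    # quantized weights
Z_in, Z_w = 3, -5                                    # example zero points

# Direct form: subtract zero points inside the loop
direct = np.sum((q_in - Z_in) * (q_w - Z_w))

# Expanded form: raw accumulation plus post-accumulation corrections
expanded = (np.sum(q_in * q_w)
            - Z_w * np.sum(q_in)
            - Z_in * np.sum(q_w)
            + K * Z_in * Z_w)

assert direct == expanded
```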
A common strategy involves these steps, which would be reflected in the generated lower-level IR:
1. Integer Accumulation: Perform the core multiply-accumulate over the reduction dimensions on the raw quantized values, widening to INT32 so the accumulator does not overflow (a NumPy sanity check follows the snippet).

```mlir
// Lowered IR Snippet 1: Integer MatMul/Conv Accumulation
%acc_i32 = arith.constant 0 : i32
// Loop over reduction dimensions (e.g., input channels, kernel spatial dims)
affine.for %k = 0 to K {
  %in_val_i8 = memref.load %input_q[...] : memref<...x!quant.uniform<i8:...>>
  %wt_val_i8 = memref.load %weight_q[...] : memref<...x!quant.uniform<i8:...>>
  // Extend i8 to i32 for accumulation
  %in_val_i32 = arith.extsi %in_val_i8 : i8 to i32
  %wt_val_i32 = arith.extsi %wt_val_i8 : i8 to i32
  // Integer multiplication
  %mul_i32 = arith.muli %in_val_i32, %wt_val_i32 : i32
  // Accumulate
  %acc_i32 = arith.addi %acc_i32, %mul_i32 : i32
}
```
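As a sanity check of this step outside the IR, the same widened accumulation in NumPy (illustrative values only) looks like:

```python
import numpy as np

rng = np.random.default_rng(1)
q_in = rng.integers(-128, 128, 256, dtype=np.int8)   # quantized input slice
q_w  = rng.integers(-128, 128, 256, dtype=np.int8)   # quantized weight slice

# Widen to i32 before multiplying: a single i8*i8 product can reach 16384,
# and summing over a long reduction axis would overflow a narrow accumulator.
acc_i32 = np.sum(q_in.astype(np.int32) * q_w.astype(np.int32), dtype=np.int32)
```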
2. Zero-Point Correction: The accumulation above multiplies the raw quantized values, but the math requires products of the form `(Input_q - Z_in) * (Weight_q - Z_w)`. Subtracting the zero points inside the loop works, but it adds subtractions to every iteration. Using the expansion shown earlier, the compiler can instead accumulate raw products and apply the zero-point terms once, after the loop:
```mlir
// Lowered IR Snippet 2: Zero-Point Correction (Post-Accumulation style)
// Assume %sum_inputs_i32, %sum_weights_i32, %reduction_size are pre-calculated or available
%zp_in_i32 = arith.constant ... : i32 // Z_in
%zp_wt_i32 = arith.constant ... : i32 // Z_w
%correction1 = arith.muli %sum_inputs_i32, %zp_wt_i32 : i32
%correction2 = arith.muli %sum_weights_i32, %zp_in_i32 : i32
%correction3 = arith.muli %reduction_size, %zp_in_i32 : i32
%correction3 = arith.muli %correction3, %zp_wt_i32 : i32 // Term Z_in * Z_w * K
%acc_i32 = arith.subi %acc_i32, %correction1 : i32
%acc_i32 = arith.subi %acc_i32, %correction2 : i32
%acc_i32 = arith.addi %acc_i32, %correction3 : i32
```
3. Bias Addition: Add the quantized bias term ($\text{Bias}_q$). Remember, the bias scale $S_{bias}$ must ideally equal $S_{in} \times S_w$ for direct INT32 addition. If not, the bias must be rescaled before addition (see the sketch after the snippet below).
```mlir
// Lowered IR Snippet 3: Bias Addition
%bias_val_i32 = memref.load %bias_q[...] : memref<...x!quant.uniform<i32:...>>
// Assume bias scale matches S_in * S_w, otherwise rescaling needed here
%acc_i32 = arith.addi %acc_i32, %bias_val_i32 : i32
```
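If the bias was quantized with a scale other than $S_{in} \times S_w$, it has to be re-expressed in the accumulator's scale before the addition. A minimal sketch of that rescaling, using made-up scale values:

```python
import numpy as np

S_in, S_w = 0.02, 0.005          # input and weight scales (illustrative)
acc_scale = S_in * S_w           # scale of the INT32 accumulator

bias_fp32 = np.array([0.125, -0.75, 0.3], dtype=np.float32)

# Quantize the bias directly into the accumulator scale (zero point 0),
# so it can be added to the INT32 accumulator with a plain integer add.
bias_i32 = np.round(bias_fp32 / acc_scale).astype(np.int32)   # -> [1250, -7500, 3000]
```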
4. Requantization Scaling: Rescale the INT32 accumulator from the combined scale $S_{in} \times S_w$ to the output scale $S_{out}$. The real-valued multiplier $M = S_{in}S_w / S_{out}$ is represented as an integer multiplier $M_0$ and a right shift $N$, so the scaling becomes a fixed-point multiply followed by a rounding shift (a sketch of deriving $M_0$ and $N$ follows the snippet).

```mlir
// Lowered IR Snippet 4: Requantization Scaling
%requant_mult_i32 = arith.constant ... : i32 // M0
%requant_shift_i32 = arith.constant ... : i32 // N
// Perform fixed-point multiplication: (acc * M0) >> N
// Often requires widening to i64 to avoid overflow during multiplication
%acc_i64 = arith.extsi %acc_i32 : i32 to i64
%requant_mult_i64 = arith.extsi %requant_mult_i32 : i32 to i64
%scaled_acc_i64 = arith.muli %acc_i64, %requant_mult_i64 : i64
// Apply rounding before the shift: add 2^(N-1), i.e. half of the final LSB
%rounding_delta = arith.constant ... : i64 // 1 << (N - 1)
%scaled_acc_i64 = arith.addi %scaled_acc_i64, %rounding_delta : i64
// Perform the shift (arithmetic right shift); the shift amount must match the operand type
%requant_shift_i64 = arith.extsi %requant_shift_i32 : i32 to i64
%shifted_acc_i64 = arith.shrsi %scaled_acc_i64, %requant_shift_i64 : i64
```
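The snippet leaves the derivation of $M_0$ and $N$ implicit. A common approach, similar in spirit to gemmlowp/TFLite-style runtimes, normalizes the real multiplier $M = S_{in}S_w / S_{out}$ into $[0.5, 1)$ and encodes it as a Q31 fixed-point integer. The sketch below (illustrative scales, not taken from any specific compiler) shows that derivation and its application:

```python
import math

def quantize_multiplier(real_multiplier):
    """Express real_multiplier (0 < M < 1) as M0 / 2**N, where M0 is a
    positive 32-bit integer with 31 fractional bits (Q31)."""
    assert 0.0 < real_multiplier < 1.0
    m, e = math.frexp(real_multiplier)        # real_multiplier = m * 2**e, m in [0.5, 1)
    M0 = int(round(m * (1 << 31)))            # Q31 fixed-point mantissa
    N = 31 - e                                # total right shift
    if M0 == (1 << 31):                       # rounding pushed m up to 1.0
        M0 //= 2
        N -= 1
    return M0, N

def requantize(acc_i32, M0, N):
    """Apply (acc * M0 + rounding) >> N; Python ints are wide enough, the IR
    widens to i64 explicitly for the same reason."""
    rounded = acc_i32 * M0 + (1 << (N - 1))
    return rounded >> N                       # arithmetic right shift

# Example: S_in = 0.02, S_w = 0.005, S_out = 0.05 -> M = 0.002 (illustrative)
M0, N = quantize_multiplier(0.02 * 0.005 / 0.05)
acc = 123456
print(requantize(acc, M0, N), round(acc * 0.002))   # both print 247
```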
5. Output Zero-Point Addition: Add the output zero point $Z_{out}$.
```mlir
// Lowered IR Snippet 5: Add Output Zero Point
%zp_out_i64 = arith.constant ... : i64 // Z_out extended to i64
%final_acc_i64 = arith.addi %shifted_acc_i64, %zp_out_i64 : i64
```
6. Clamp and Cast: Clamp the result to the representable INT8 range [-128, 127], truncate it to i8, and store it.

```mlir
// Lowered IR Snippet 6: Clamp and Cast
%min_val_i64 = arith.constant -128 : i64
%max_val_i64 = arith.constant 127 : i64
%clamped_acc_i64 = arith.maxsi %final_acc_i64, %min_val_i64 : i64
%clamped_acc_i64 = arith.minsi %clamped_acc_i64, %max_val_i64 : i64
// Truncate back to i8
%output_val_i8 = arith.trunci %clamped_acc_i64 : i64 to i8
// Store the final result
memref.store %output_val_i8, %output_q[...] : memref<...x!quant.uniform<i8:...>>
```
This sequence replaces the single high-level `quant.conv2d` operation. The compiler performs this transformation based on the specific quantization parameters (scales, zero points) associated with the operation's inputs and outputs.
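Tying the six steps together, the following Python/NumPy reference sketch (illustrative quantization parameters only, not a specific compiler's output) computes one output element the way the lowered IR would and compares it with the floating-point value it approximates:

```python
import math
import numpy as np

def quantize_multiplier(m):
    """Express 0 < m < 1 as M0 / 2**N with M0 a Q31 integer (see the previous sketch)."""
    mant, e = math.frexp(m)
    M0, N = int(round(mant * (1 << 31))), 31 - e
    if M0 == (1 << 31):
        M0, N = M0 // 2, N - 1
    return M0, N

# Illustrative quantization parameters
S_in, Z_in   = 0.02, 5
S_w,  Z_w    = 0.005, 0
S_out, Z_out = 0.1, -10
K = 32                                               # reduction size

rng = np.random.default_rng(7)
q_in = rng.integers(-128, 128, K).astype(np.int64)   # quantized inputs
q_w  = rng.integers(-128, 128, K).astype(np.int64)   # quantized weights
bias_i32 = 1200                                      # bias already in scale S_in * S_w

# Step 1: raw INT32 accumulation
acc = int(np.sum(q_in * q_w))
# Step 2: post-accumulation zero-point correction
acc += -Z_w * int(np.sum(q_in)) - Z_in * int(np.sum(q_w)) + K * Z_in * Z_w
# Step 3: bias addition
acc += bias_i32
# Step 4: requantization scaling, (acc * M0 + rounding) >> N
M0, N = quantize_multiplier(S_in * S_w / S_out)
acc = (acc * M0 + (1 << (N - 1))) >> N
# Step 5: output zero-point addition
acc += Z_out
# Step 6: clamp to the INT8 range and cast
q_out = max(-128, min(127, acc))

# Floating-point reference for the same output element
ref = float(np.sum(S_in * (q_in - Z_in) * S_w * (q_w - Z_w))) + bias_i32 * S_in * S_w
print(q_out, round(ref / S_out) + Z_out)             # should agree (up to rounding)
```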
*Flow transforming a high-level quantized operation into a sequence of lower-level integer arithmetic and control operations during compiler lowering.*
Compiler frameworks like MLIR use dialect conversion and rewrite patterns to automate this lowering process. The specific sequence and optimizations (e.g., how zero points are handled, or the choice of fixed-point multiplication method) depend on the compiler's sophistication and the target hardware capabilities. For instance, hardware with specialized instructions like Intel's VNNI or ARM's dot-product instructions might lead to a different lowered representation that maps directly onto those instructions.
This hands-on perspective shows that "running a quantized model" involves substantial compiler transformation. Understanding this lowering process is valuable for debugging performance issues, designing custom quantized operators, or extending compilers to support new low-precision data types or hardware features. It bridges the abstract concept of quantization with the concrete integer arithmetic executed on the processor.