Alright, let's put theory into practice. We've discussed how quantized operations are represented conceptually and within compiler IRs. Now, we'll focus on the critical step of lowering: transforming these high-level quantized operations into sequences of lower-level operations, typically standard integer arithmetic, that explicitly handle scaling, zero-point adjustments, and type conversions. This process is fundamental for generating executable code, especially for hardware that lacks direct support for arbitrarily scaled quantized types but possesses efficient integer arithmetic units.
Imagine we have a high-level quantized operation in our IR, perhaps representing a 2D convolution or matrix multiplication. Conceptually, it looks something like this (using a simplified, MLIR-inspired notation):
// High-Level Representation (Conceptual)
%input_q = "quant.cast"(%input_fp32) : tensor<1xHxWxCinxf32> -> tensor<1xHxWxCin x !quant.uniform<i8:...>>
%weight_q = "quant.cast"(%weight_fp32) : tensor<CoutxKhxKwxCinxf32> -> tensor<CoutxKhxKwxCin x !quant.uniform<i8:...>>
%bias_q = "quant.cast"(%bias_fp32) : tensor<Coutxf32> -> tensor<Cout x !quant.uniform<i32:...>> // Bias often INT32
// Represents the entire quantized convolution including output scaling
%output_q = "quant.conv2d"(%input_q, %weight_q, %bias_q) {
strides = [...], padding = [...],
output_scale = ..., output_zero_point = ...
} : (..., ..., ...) -> tensor<1xHoxWoxCout x !quant.uniform<i8:...>>
This quant.conv2d operation encapsulates the core computation along with the necessary requantization logic to produce the final INT8 output. Our goal is to lower this into operations available in standard dialects (like arith for arithmetic and memref for memory access in MLIR) or into target-specific intrinsics.
Recall the affine quantization formula: $r = S \times (q - Z)$, where $r$ is the real value, $q$ is the quantized integer value, $S$ is the scale, and $Z$ is the zero point.
For a convolution (or matrix multiplication), the core computation involves multiply-accumulate operations. A single output element is a sum of products:
$$\text{Output}_{real} = \sum (\text{Input}_{real} \times \text{Weight}_{real}) + \text{Bias}_{real}$$
Substituting the quantization formula for each term gives:
$$S_{out}(\text{Output}_q - Z_{out}) \approx \sum \left[ S_{in}(\text{Input}_q - Z_{in}) \times S_w(\text{Weight}_q - Z_w) \right] + S_{bias}(\text{Bias}_q - Z_{bias})$$
Our target is to compute $\text{Output}_q$ using integer arithmetic, which requires rearranging this equation. The full expansion can get complex:
$$S_{out}(\text{Output}_q - Z_{out}) \approx S_{in} S_w \sum (\text{Input}_q - Z_{in})(\text{Weight}_q - Z_w) + S_{bias}(\text{Bias}_q - Z_{bias})$$
$$S_{out}(\text{Output}_q - Z_{out}) \approx S_{in} S_w \sum \left[ \text{Input}_q \text{Weight}_q - \text{Input}_q Z_w - Z_{in} \text{Weight}_q + Z_{in} Z_w \right] + S_{bias}(\text{Bias}_q - Z_{bias})$$
The compiler's task during lowering is to generate efficient integer code that computes the right-hand side and then solves for $\text{Output}_q$.
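To make the substitution concrete, here is a minimal NumPy sketch of the core term for a dot product; the scales, zero points, and vectors are invented for illustration:
import numpy as np

# Hypothetical quantization parameters
s_in, z_in = 0.02, 5      # input scale and zero point
s_w, z_w = 0.01, 0        # weight scale and zero point (symmetric weights)

# Real-valued input and weight vectors
x_real = np.array([0.4, -0.2, 0.1, 0.3])
w_real = np.array([0.05, -0.03, 0.02, 0.01])

# Quantize with q = round(r / S) + Z, clamped to the int8 range
x_q = np.clip(np.round(x_real / s_in) + z_in, -128, 127).astype(np.int32)
w_q = np.clip(np.round(w_real / s_w) + z_w, -128, 127).astype(np.int32)

# Real dot product vs. its quantized approximation S_in * S_w * sum((x_q - Z_in)(w_q - Z_w))
real_dot = float(np.dot(x_real, w_real))
quant_dot = s_in * s_w * float(np.sum((x_q - z_in) * (w_q - z_w)))
print(real_dot, quant_dot)   # the two values agree up to rounding error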
A common strategy involves these steps, which would be reflected in the generated lower-level IR:
Integer Accumulation: Perform the primary multiply-accumulate operations using the quantized integer inputs ($\text{Input}_q$, $\text{Weight}_q$). The result is held in a wider intermediate accumulator, typically INT32.
// Conceptual Lowered IR Snippet 1: Integer MatMul/Conv Accumulation
%acc_i32 = arith.constant 0 : i32
// Loop over reduction dimensions (e.g., input channels, kernel spatial dims)
affine.for %k = 0 to K {
%in_val_i8 = memref.load %input_q[...] : memref<...x!quant.uniform<i8:...>>
%wt_val_i8 = memref.load %weight_q[...] : memref<...x!quant.uniform<i8:...>>
// Extend i8 to i32 for accumulation
%in_val_i32 = arith.extsi %in_val_i8 : i8 to i32
%wt_val_i32 = arith.extsi %wt_val_i8 : i8 to i32
// Integer multiplication
%mul_i32 = arith.muli %in_val_i32, %wt_val_i32 : i32
// Accumulate
%acc_i32 = arith.addi %acc_i32, %mul_i32 : i32
}
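The same pattern can be sketched in NumPy; the values are hypothetical and the point is the int8-to-int32 widening before multiplication, mirroring the arith.extsi / arith.muli / arith.addi sequence above:
import numpy as np

x_q = np.array([100, -50, 127, -128], dtype=np.int8)
w_q = np.array([50, 100, -120, 77], dtype=np.int8)

acc = np.int32(0)
for xi, wi in zip(x_q, w_q):
    # Extend each int8 operand to int32 before multiplying so the
    # product and running sum cannot overflow an 8-bit value
    acc = acc + np.int32(xi) * np.int32(wi)
print(acc)   # -25096, far outside the int8 range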
Zero-Point Correction: Apply corrections based on the input ($Z_{in}$) and weight ($Z_w$) zero points. This can be done in several ways. One option is to subtract the zero points inside the inner loop, computing $(\text{Input}_q - Z_{in}) \times (\text{Weight}_q - Z_w)$ directly; this adds subtractions inside the loop. Another is to expand the product and apply the correction terms after the accumulation, as shown below.
// Conceptual Lowered IR Snippet 2: Zero-Point Correction (Post-Accumulation style)
// Assume %sum_inputs_i32, %sum_weights_i32, %reduction_size are pre-calculated or available
%zp_in_i32 = arith.constant ... : i32 // Z_in
%zp_wt_i32 = arith.constant ... : i32 // Z_w
%correction1 = arith.muli %sum_inputs_i32, %zp_wt_i32 : i32
%correction2 = arith.muli %sum_weights_i32, %zp_in_i32 : i32
%correction3 = arith.muli %reduction_size, %zp_in_i32 : i32
%correction3 = arith.muli %correction3, %zp_wt_i32 : i32 // Term Z_in * Z_w * K
%acc_i32 = arith.subi %acc_i32, %correction1 : i32
%acc_i32 = arith.subi %acc_i32, %correction2 : i32
%acc_i32 = arith.addi %acc_i32, %correction3 : i32
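A quick NumPy check (with made-up int32-widened values and zero points) confirms that the post-accumulation corrections reproduce the directly subtracted form:
import numpy as np

x_q = np.array([25, -5, 10, 20], dtype=np.int32)
w_q = np.array([5, -3, 2, 1], dtype=np.int32)
z_in, z_w = 5, 2
k = x_q.size   # reduction size K

# Raw integer accumulation (Snippet 1)
acc = np.sum(x_q * w_q)

# Post-accumulation corrections (Snippet 2)
corrected = acc - np.sum(x_q) * z_w - np.sum(w_q) * z_in + k * z_in * z_w

# Pre-subtraction form, for comparison
direct = np.sum((x_q - z_in) * (w_q - z_w))
print(corrected == direct)   # True: the two strategies agree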
Bias Addition: Add the quantized bias term ($\text{Bias}_q$). The bias scale $S_{bias}$ should equal $S_{in} \times S_w$ so the INT32 bias can be added directly to the accumulator; if it does not, the bias must be rescaled before addition.
// Conceptual Lowered IR Snippet 3: Bias Addition
%bias_val_i32 = memref.load %bias_q[...] : memref<...x!quant.uniform<i32:...>>
// Assume bias scale matches S_in * S_w, otherwise rescaling needed here
%acc_i32 = arith.addi %acc_i32, %bias_val_i32 : i32
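For reference, the usual way to make the bias scale line up is to quantize the bias offline with scale $S_{in} \times S_w$ and a zero point of 0 (this is, for example, the convention in TFLite's quantization scheme). A sketch with invented values:
import numpy as np

# Hypothetical float bias and the scales of the preceding multiply
bias_fp = np.array([0.125, -0.5, 0.03])
s_in, s_w = 0.02, 0.01

# Quantize the bias with scale S_in * S_w and zero point 0 so the INT32
# values can be added directly to the accumulator
s_bias = s_in * s_w
bias_q = np.round(bias_fp / s_bias).astype(np.int32)
print(bias_q)   # [  625 -2500   150]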
Requantization Scaling: Apply the final scale factor to convert the INT32 accumulator back towards the target output precision (e.g., INT8). The combined scale factor is $M = \frac{S_{in} S_w}{S_{out}}$. This is almost never an integer or a power of two, so it is implemented using fixed-point multiplication: the compiler calculates an integer multiplier $M_0$ and a right-shift amount $N$ such that $M \approx M_0 / 2^N$.
// Conceptual Lowered IR Snippet 4: Requantization Scaling
%requant_mult_i32 = arith.constant ... : i32 // M0
%requant_shift_i32 = arith.constant ... : i32 // N
// Perform fixed-point multiplication: (acc * M0) >> N
// Often requires widening to i64 to avoid overflow during multiplication
%acc_i64 = arith.extsi %acc_i32 : i32 to i64
%requant_mult_i64 = arith.extsi %requant_mult_i32 : i32 to i64
%scaled_acc_i64 = arith.muli %acc_i64, %requant_mult_i64 : i64
// Apply rounding by adding half of the divisor, 1 << (N - 1), before shifting
%rounding_delta = arith.constant ... : i64 // 1 << (N - 1)
%scaled_acc_i64 = arith.addi %scaled_acc_i64, %rounding_delta : i64
// Perform the arithmetic right shift by N
%requant_shift_i64 = arith.extsi %requant_shift_i32 : i32 to i64
%shifted_acc_i64 = arith.shrsi %scaled_acc_i64, %requant_shift_i64 : i64
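The multiplier/shift pair is typically derived offline. Below is a minimal Python sketch of that derivation, assuming a 31-bit multiplier and $0 < M < 1$ (which holds for common scale combinations); it is similar in spirit to, but simpler than, the normalization used by libraries such as gemmlowp:
def quantize_multiplier(m, bits=31):
    # Represent m as m0 * 2**(-shift), with m0 carrying `bits` bits of precision
    if not 0.0 < m < 1.0:
        raise ValueError("this sketch assumes 0 < M < 1")
    shift = 0
    while m < 0.5:           # normalize the mantissa into [0.5, 1)
        m *= 2.0
        shift += 1
    m0 = int(round(m * (1 << bits)))
    return m0, shift + bits  # total right-shift amount N

s_in, s_w, s_out = 0.02, 0.01, 0.05
m = (s_in * s_w) / s_out          # combined requantization scale M
m0, n = quantize_multiplier(m)

acc = -25096                      # example INT32 accumulator value
rounding = 1 << (n - 1)
scaled = (acc * m0 + rounding) >> n   # Python's >> on ints is an arithmetic (flooring) shift
print(scaled, round(acc * m))         # -100 -100: the fixed-point result matches the float scaling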
Output Zero-Point Addition: Add the output zero point $Z_{out}$.
// Conceptual Lowered IR Snippet 5: Add Output Zero Point
%zp_out_i64 = arith.constant ... : i64 // Z_out extended to i64
%final_acc_i64 = arith.addi %shifted_acc_i64, %zp_out_i64 : i64
Clamping and Casting: Clamp the result to the valid range of the target output type (e.g., [-128, 127] for INT8) and cast to the final type.
// Conceptual Lowered IR Snippet 6: Clamp and Cast
%min_val_i64 = arith.constant -128 : i64
%max_val_i64 = arith.constant 127 : i64
%clamped_acc_i64 = arith.maxsi %final_acc_i64, %min_val_i64 : i64
%clamped_acc_i64 = arith.minsi %clamped_acc_i64, %max_val_i64 : i64
// Truncate back to i8
%output_val_i8 = arith.trunci %clamped_acc_i64 : i64 to i8
// Store the final result
memref.store %output_val_i8, %output_q[...] : memref<...x!quant.uniform<i8:...>>
This sequence replaces the single high-level quant.conv2d operation. The compiler performs this transformation based on the specific quantization parameters (scales, zero points) associated with the operation's inputs and outputs.
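Putting the six steps together, the following NumPy sketch traces the whole lowered sequence for one output element (a dot product plus bias); every scale and zero point here is invented, and the requantization uses floating point for brevity where a real lowering would use the multiplier/shift pair shown earlier:
import numpy as np

# Hypothetical quantization parameters
s_in, z_in = 0.02, 5
s_w, z_w = 0.01, 0
s_out, z_out = 0.05, -10

x_real = np.array([0.4, -0.2, 0.1, 0.3])
w_real = np.array([0.05, -0.03, 0.02, 0.01])
bias_real = 0.02

# Quantize operands: int8 inputs/weights, int32 bias with scale S_in * S_w
x_q = np.round(x_real / s_in).astype(np.int32) + z_in
w_q = np.round(w_real / s_w).astype(np.int32) + z_w
bias_q = int(round(bias_real / (s_in * s_w)))

# Steps 1-3: integer accumulation, zero-point correction, bias addition
acc = int(np.sum(x_q * w_q))
acc -= int(np.sum(x_q)) * z_w + int(np.sum(w_q)) * z_in
acc += x_q.size * z_in * z_w + bias_q

# Steps 4-6: requantize, add the output zero point, clamp to the int8 range
out_q = int(round(acc * (s_in * s_w) / s_out)) + z_out
out_q = max(-128, min(127, out_q))

# Floating-point reference for comparison
ref = float(np.dot(x_real, w_real)) + bias_real
print(out_q, int(round(ref / s_out)) + z_out)   # both -9 for these values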
Diagram: flow transforming a high-level quantized operation into a sequence of lower-level integer arithmetic and control operations during compiler lowering.
Compiler frameworks like MLIR use dialect conversion and rewrite patterns to automate this lowering process. The specific sequence and optimizations (e.g., how zero points are handled, or the choice of fixed-point multiplication method) depend on the compiler's sophistication and the target hardware capabilities. For instance, hardware with specialized instructions like Intel's VNNI or ARM's dot-product instructions might lead to a different lowered representation that maps directly onto those instructions.
This hands-on perspective shows that "running a quantized model" involves significant compiler transformations. Understanding this lowering process is valuable for debugging performance issues, designing custom quantized operators, or extending compilers to support new low-precision data types or hardware features. It bridges the abstract concept of quantization with the concrete integer arithmetic executed on the processor.