To address the potential accuracy degradation seen with Post-Training Quantization (PTQ), especially at lower bit depths, Quantization-Aware Training (QAT) takes a different path. Instead of quantizing a fully trained model, QAT introduces the effects of quantization during the training or fine-tuning process itself. The core idea isn't to perform the entire training using low-precision arithmetic directly, which poses challenges for gradient calculation. Instead, QAT simulates the impact of quantization within the standard floating-point training framework.
The mechanism enabling this simulation is often called "fake quantization" or "simulated quantization". It involves inserting special nodes or operations into the model's computation graph during training. In the forward pass, these fake quantization operations perform the following steps:

1. Quantize: map the incoming floating-point values onto the target low-precision grid (divide by the scale, shift by the zero point, round, and clamp to the representable range).
2. Dequantize: immediately map those grid values back to the floating-point domain using the same scale and zero point.
The output of a fake quantization node is therefore still a floating-point tensor. However, its values are now constrained; they represent only those specific floating-point numbers that can be exactly represented by the target low-precision data type (e.g., INT8).
Mathematically, if $x$ is the input floating-point tensor, the fake-quantized tensor $x_{fq}$ can be written as:

$$
x_{fq} = \text{dequantize}\big(\text{quantize}(x, \text{scale}, \text{zero\_point}), \text{scale}, \text{zero\_point}\big)
$$

Here, quantize maps $x$ to the low-precision domain (such as INT8), and dequantize maps it back to the floating-point domain. The important point is that $x_{fq}$ carries the error, or information loss, inherent in the quantization process.
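To make this concrete, here is a minimal sketch of a fake quantization operation in PyTorch, assuming a simple INT8 scheme with a fixed scale and zero point. The function name fake_quantize and the hard-coded values are illustrative, not part of any particular library API.

```python
import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Simulate INT8 quantization: map onto the integer grid, then map back."""
    # quantize: scale, shift by the zero point, round, and clamp to the INT8 range
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    # dequantize: map the integers back into the floating-point domain
    return (q - zero_point) * scale

x = torch.randn(4)
x_fq = fake_quantize(x, scale=0.02, zero_point=0)
print(x)     # arbitrary floating-point values
print(x_fq)  # values snapped onto the grid representable in INT8
```

Note that x_fq is still a floating-point tensor; the difference between x_fq and x is exactly the quantization error the model is exposed to during training.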
During the forward pass of training, these fake quantization operations are placed strategically within the model architecture. Common locations include:

- On the weights of convolutional and linear (fully connected) layers, just before they are used in the layer's computation.
- On activation tensors, typically at the outputs of layers or activation functions, before they flow into the next operation.
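As an illustration of these placements, the sketch below shows a linear layer that fake-quantizes both its weights and its output activations, reusing the fake_quantize helper sketched above. The class name QATLinear and the hard-coded scales are illustrative; in practice the scales and zero points are derived from observed value ranges or learned during training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QATLinear(nn.Module):
    """Linear layer with simulated quantization on its weights and activations."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # fake-quantize the weights just before they are used
        w_fq = fake_quantize(self.linear.weight, scale=0.01, zero_point=0)
        out = F.linear(x, w_fq, self.linear.bias)
        # fake-quantize the output activations before the next layer sees them
        return fake_quantize(out, scale=0.05, zero_point=0)
```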
The diagram below illustrates conceptually how a fake quant node alters the forward pass for an activation tensor flowing between two layers.
Conceptual flow comparing standard training and QAT forward pass. In QAT, the fake quantization node simulates the effect of quantization on the activations before they are passed to the next layer.
By performing calculations with these slightly altered, "quantization-error-aware" tensors ($x_{fq}$), the model learns to adjust its weights during training. The optimization process (such as stochastic gradient descent) implicitly minimizes the loss function while accounting for the precision limitations that will be present after true quantization. The model adapts its parameters to become more robust to the noise and information loss introduced by quantization.
While the forward pass simulation is relatively straightforward, the backward pass presents a challenge. The quantization function itself (the mapping from float to integer) involves rounding or flooring operations, whose gradient is zero almost everywhere and undefined at the step boundaries. This means standard backpropagation cannot propagate a useful gradient through the quantize step within the fake quantization node.
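A small experiment makes the problem visible: asking autograd for the gradient of a rounding operation simply returns zeros, so no learning signal would reach the weights.

```python
import torch

x = torch.linspace(-2.0, 2.0, 5, requires_grad=True)
torch.round(x).sum().backward()
print(x.grad)  # tensor([0., 0., 0., 0., 0.]) -- rounding blocks the gradient
```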
To overcome this, QAT relies on an approximation technique for calculating gradients during the backward pass. This technique, known as the Straight-Through Estimator (STE), effectively allows gradients to bypass the non-differentiable quantization step, enabling the model's weights to be updated. We will examine the STE in detail in the next section.
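As a brief preview, one common way to realize the STE is the "detach trick": the forward pass returns the quantize-dequantize value, while the backward pass treats the whole operation as the identity. The sketch below illustrates the idea (names and the INT8 scheme are assumptions for this example); the reasoning behind it is covered in the next section.

```python
import torch

def fake_quantize_ste(x, scale, zero_point, qmin=-128, qmax=127):
    # forward: ordinary quantize -> dequantize
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    x_fq = (q - zero_point) * scale
    # straight-through estimator: the detached term contributes no gradient,
    # so the backward pass sees this operation as the identity function
    return x + (x_fq - x).detach()

w = torch.randn(3, requires_grad=True)
fake_quantize_ste(w, scale=0.02, zero_point=0).sum().backward()
print(w.grad)  # tensor([1., 1., 1.]) -- gradients now flow to the weights
```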
In summary, simulating quantization effects during training via fake quantization nodes is the core mechanism of QAT. It allows the model to proactively adapt to the constraints of low-precision arithmetic, often leading to significantly better accuracy in the final quantized model compared to PTQ, especially for aggressive quantization (e.g., 4-bit).