Building upon the concepts of quantization discussed earlier, this section provides practical guidance on implementing two primary techniques: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT) for large language models. We will transition from theory to application, focusing on the typical workflow, code implementation patterns using common libraries, and evaluating the resulting models. The goal is to equip you with the hands-on skills needed to apply these powerful optimization methods effectively.
We assume you are comfortable with Python and deep learning frameworks like PyTorch, along with the Hugging Face Transformers library. For this practical exercise, we'll outline the steps using PyTorch's built-in quantization tools and concepts adaptable to libraries like Hugging Face Optimum or Intel's Neural Compressor, which often provide higher-level APIs.
Before starting, ensure your environment includes the necessary libraries:

- torch and torchvision (PyTorch core and utilities)
- transformers (for loading pre-trained LLMs and tokenizers)
- datasets (for handling calibration and evaluation data)
- optimum or other specialized quantization libraries for streamlined workflows

We recommend using a relatively small pre-trained transformer model (such as distilgpt2, bert-base-uncased, or a manageable slice of a larger model) for these examples to keep computation times reasonable.
# Example environment setup (conceptual)
# pip install torch torchvision torchaudio
# pip install transformers datasets evaluate accelerate bitsandbytes # Add others as needed
PTQ is attractive because it doesn't require retraining the model. It involves calibrating the model on a small, representative dataset to determine the optimal quantization parameters (scale and zero-point) for weights and activations.
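To make that concrete, here is a minimal sketch (independent of any particular library) of how an asymmetric INT8 scale and zero-point can be derived from an observed value range and used to quantize a tensor; the observed range would normally come from the calibration observers.

import torch

def asymmetric_int8_params(x_min, x_max, qmin=-128, qmax=127):
    # The scale maps the observed float range onto the integer range
    scale = (x_max - x_min) / (qmax - qmin)
    # The zero-point is the integer that represents the float value 0.0
    zero_point = int(round(qmin - x_min / scale))
    return scale, max(qmin, min(qmax, zero_point))

x = torch.randn(4, 8)  # stand-in for an observed activation tensor
scale, zp = asymmetric_int8_params(x.min().item(), x.max().item())
x_int8 = torch.clamp(torch.round(x / scale) + zp, -128, 127)
x_dequant = (x_int8 - zp) * scale  # what the model effectively "sees" after dequantization
print(f"scale={scale:.5f}, zero_point={zp}, max error={(x - x_dequant).abs().max():.5f}")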
First, load your target pre-trained model and its corresponding tokenizer. Ensure the model is set to evaluation mode.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "distilgpt2" # Example model
model_fp32 = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token # GPT-2-style tokenizers have no pad token; needed for padding below
model_fp32.eval() # Set to evaluation mode
Select a dataset that reflects the kind of data the model will encounter during inference. A few hundred samples are often sufficient. Preprocess this data using the model's tokenizer.
from datasets import load_dataset
from torch.utils.data import DataLoader
# Load a small calibration dataset (e.g., 100 examples from wikitext)
calibration_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]")
def preprocess_function(examples):
    # Adjust max_length as needed for your model/task
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
tokenized_calibration_data = calibration_data.map(preprocess_function, batched=True)
tokenized_calibration_data.set_format(type="torch", columns=["input_ids", "attention_mask"])
# Use a DataLoader for batching
calibration_dataloader = DataLoader(tokenized_calibration_data, batch_size=8)
Configure the quantization settings. PyTorch's torch.quantization module offers defaults (like get_default_qconfig) or allows fine-grained control (e.g., symmetric per-channel quantization for weights and asymmetric per-tensor quantization for activations).
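For example, a custom configuration along those lines can be built from PyTorch's observer classes (shown here as a sketch using the torch.quantization namespace; in recent versions these also live under torch.ao.quantization):

import torch
from torch.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver

custom_qconfig = QConfig(
    # Asymmetric per-tensor quantization for activations
    activation=MinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    # Symmetric per-channel quantization for weights
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)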
Prepare the model by inserting observers (torch.quantization.prepare). These observers collect statistics about activation ranges during calibration.
# Configure PTQ - choose a backend appropriate for your target hardware
# 'fbgemm' (x86 servers) and 'qnnpack' (ARM/mobile) are the common eager-mode backends
import copy

qconfig = torch.quantization.get_default_qconfig('fbgemm') # For x86
# Prepare model for PTQ: inserts observers
# Note: the eager-mode quantization API is evolving (much of it now lives under
# torch.ao.quantization); check the PyTorch docs for your version.
# Some libraries use different functions or context managers.
model_to_quantize = copy.deepcopy(model_fp32) # Work on a copy
model_to_quantize.qconfig = qconfig
torch.quantization.prepare(model_to_quantize, inplace=True) # Eager-mode API
# Calibrate the model by running data through it
print("Running calibration...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_to_quantize.to(device)
with torch.no_grad():
    for batch in calibration_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        try:
            model_to_quantize(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            # Depending on the model/library, specific input formats may be required
            print(f"Warning: Batch failed during calibration: {e}")
print("Calibration finished.")
Note: Handling model inputs and outputs correctly during calibration is important. Some model architectures or quantization backends have specific requirements, and the eager-mode prepare/convert API in PyTorch is still evolving, so consult the documentation for your specific version. Libraries like Hugging Face Optimum often simplify this process significantly.
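As an illustration of such a higher-level workflow, here is a hedged sketch of INT8 quantization with Optimum's ONNX Runtime integration (shown with dynamic quantization, which skips the calibration step; static PTQ is also supported). The class and argument names (ORTModelForCausalLM, ORTQuantizer, AutoQuantizationConfig, export=True) reflect recent Optimum releases and may differ in your installed version.

from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export the model to ONNX, then quantize it (dynamic quantization needs no calibration data)
ort_model = ORTModelForCausalLM.from_pretrained("distilgpt2", export=True)
quantizer = ORTQuantizer.from_pretrained(ort_model)
dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer.quantize(save_dir="distilgpt2-onnx-int8", quantization_config=dqconfig)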
After calibration, convert the model to its quantized equivalent using torch.quantization.convert. This replaces the observed modules with their quantized counterparts.
model_to_quantize.cpu() # Conversion often done on CPU
model_ptq = torch.quantization.convert(model_to_quantize, inplace=False)
print("PTQ conversion complete.")
Evaluate the PTQ model against the original FP32 model. Key metrics include model size on disk, task quality (here, perplexity), and inference latency.
import os
import time
# 1. Model Size
torch.save(model_fp32.state_dict(), "distilgpt2_fp32.pth")
torch.save(model_ptq.state_dict(), "distilgpt2_ptq_int8.pth")
fp32_size = os.path.getsize("distilgpt2_fp32.pth") / 1e6
ptq_size = os.path.getsize("distilgpt2_ptq_int8.pth") / 1e6
print(f"FP32 Model Size: {fp32_size:.2f} MB")
print(f"PTQ INT8 Model Size: {ptq_size:.2f} MB")
print(f"Size Reduction: {(1 - ptq_size / fp32_size) * 100:.2f}%")
# 2. Accuracy (Example: Perplexity on a test set)
test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:50]") # Small test set
test_encodings = tokenizer("\n\n".join(test_data["text"]), return_tensors="pt")

def compute_perplexity(model, encodings, max_length=512, device="cpu"):
    """Perplexity of a causal LM over non-overlapping windows of the test text.
    (The Hugging Face evaluate 'perplexity' metric reloads a model from the Hub by
    model_id, so we compute perplexity directly for our local quantized models.)"""
    model.to(device)
    model.eval()
    input_ids = encodings.input_ids
    total_nll, total_tokens = 0.0, 0
    with torch.no_grad():
        for start in range(0, input_ids.size(1), max_length):
            ids = input_ids[:, start:start + max_length].to(device)
            if ids.size(1) < 2:
                continue
            loss = model(ids, labels=ids).loss # mean NLL over the predicted tokens
            total_nll += loss.item() * (ids.size(1) - 1)
            total_tokens += ids.size(1) - 1
    return float(torch.exp(torch.tensor(total_nll / total_tokens)))

# Eager-mode INT8 modules execute on CPU, so evaluate both models on CPU for a like-for-like comparison
results_fp32 = compute_perplexity(model_fp32, test_encodings, device="cpu")
results_ptq = compute_perplexity(model_ptq, test_encodings, device="cpu")
print(f"FP32 Perplexity: {results_fp32:.4f}")
print(f"PTQ INT8 Perplexity: {results_ptq:.4f}")
# 3. Inference Speed (Simple Latency Example)
# Benchmark on CPU: the eager-mode INT8 kernels (fbgemm/qnnpack) only run on CPU
model_fp32.cpu()
model_ptq.cpu()
dummy_input = tokenizer("This is a sample text for benchmarking.", return_tensors="pt").input_ids
repetitions = 100
with torch.no_grad():
    # Warm-up runs
    for _ in range(10):
        _ = model_fp32(dummy_input)
        _ = model_ptq(dummy_input)
    # Timed runs
    start_time = time.time()
    for _ in range(repetitions):
        _ = model_fp32(dummy_input)
    fp32_latency = (time.time() - start_time) / repetitions * 1000 # milliseconds
    start_time = time.time()
    for _ in range(repetitions):
        _ = model_ptq(dummy_input)
    ptq_latency = (time.time() - start_time) / repetitions * 1000 # milliseconds
print(f"FP32 Average Latency: {fp32_latency:.2f} ms")
print(f"PTQ INT8 Average Latency: {ptq_latency:.2f} ms")
print(f"Speedup: {fp32_latency / ptq_latency:.2f}x")
PTQ typically yields significant size reduction and speedup, especially on hardware with native INT8 support, but often comes with a small drop in accuracy.
QAT simulates quantization effects during a fine-tuning phase, allowing the model to adapt its weights to the reduced precision. This usually recovers some or all of the accuracy lost during PTQ, at the cost of requiring additional training compute.
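Conceptually, each fake-quantization module rounds values onto the INT8 grid in the forward pass but lets gradients pass through unchanged (the straight-through estimator). A minimal sketch of the idea (not PyTorch's actual FakeQuantize implementation):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize-dequantize in the forward pass...
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    x_dq = (q - zero_point) * scale
    # ...but treat the rounding as the identity in the backward pass (straight-through estimator)
    return x + (x_dq - x).detach()

x = torch.randn(3, 4, requires_grad=True)
y = fake_quantize(x, scale=0.05, zero_point=0)
y.sum().backward()
print(x.grad)  # all ones: the rounding step is invisible to the gradient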
Load the pre-trained model again. Instead of evaluation mode, you'll start in training mode. Use a QAT-specific configuration (get_default_qat_qconfig) and prepare the model using torch.quantization.prepare_qat. This inserts "fake quantization" modules that mimic the effects of quantization during both the forward and backward passes.
# Load the base FP32 model again
model_for_qat = AutoModelForCausalLM.from_pretrained(model_id)
# Configure QAT
# Note: Backend choice ('fbgemm', 'qnnpack') affects supported ops and performance
qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# Prepare model for QAT: inserts fake quantization modules
model_for_qat.train() # Set to training mode
model_for_qat.qconfig = qat_qconfig
# Important: Layer fusion can improve QAT accuracy and performance
# Example: Fuse Conv-BN-ReLU, Linear-ReLU etc. where applicable
# fuse_list = [["module_a_name", "module_b_name"], ...] # Names of adjacent layers to fuse
# model_for_qat = torch.quantization.fuse_modules(model_for_qat, fuse_list) # Apply fusion
torch.quantization.prepare_qat(model_for_qat, inplace=True)
print("Model prepared for QAT.")
Note: For transformer models, identifying fusible layers (like Linear -> Activation) can be beneficial but requires careful inspection of the model architecture. Some libraries automate parts of this.
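As a self-contained illustration of the API (a toy module, not distilgpt2's actual layer names), fusing a Linear-ReLU pair looks like this:

import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
toy.eval()  # fuse_modules expects eval mode; a QAT-specific variant exists for training-mode fusion
fused = torch.quantization.fuse_modules(toy, [["0", "1"]])  # fuse Linear + ReLU into a single LinearReLU
print(fused)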
Fine-tune the prepared model on a relevant dataset (e.g., the downstream task dataset or a general text corpus) for a small number of epochs. Use a standard training loop, but be mindful that QAT often requires lower learning rates and careful hyperparameter tuning compared to standard FP32 fine-tuning. The gradients will flow through the fake quantization operations, allowing the model to adjust.
from torch.optim import AdamW
from transformers import get_scheduler

# Prepare training dataset (e.g., fine-tuning on a specific task or continuing LM)
# Use a different split or dataset than calibration for realism
train_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[100:1100]") # Example: 1000 samples
tokenized_train_data = train_dataset.map(preprocess_function, batched=True)
# Labels are derived from input_ids in the loop below, so only these columns are needed
tokenized_train_data.set_format(type="torch", columns=["input_ids", "attention_mask"])
train_dataloader = DataLoader(tokenized_train_data, batch_size=4, shuffle=True) # Smaller batch size often helps
# Optimizer and Scheduler
optimizer = AdamW(model_for_qat.parameters(), lr=1e-5) # Typically lower LR for QAT
num_epochs = 1 # Adjust as needed (e.g., 1-3)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
# Fine-tuning Loop
print("Starting QAT fine-tuning...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_for_qat.to(device)
model_for_qat.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # For causal LM, labels are the input_ids (the model shifts them internally)
        labels = input_ids.clone()
        outputs = model_for_qat(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward() # Gradients flow through fake quant nodes
        optimizer.step()
        lr_scheduler.step()
    print(f"QAT Epoch {epoch+1} complete. Last loss: {loss.item():.4f}")
print("QAT fine-tuning finished.")
After fine-tuning, convert the QAT model to a truly quantized integer model, similar to the PTQ process. The model must be in evaluation mode for conversion.
model_for_qat.cpu() # Conversion often done on CPU
model_for_qat.eval() # Set to eval mode before final conversion
model_qat = torch.quantization.convert(model_for_qat, inplace=False)
print("QAT conversion complete.")
Evaluate the QAT model using the same metrics as the PTQ model (size, accuracy, speed). Compare the results against both the FP32 baseline and the PTQ model.
# 1. Model Size
torch.save(model_qat.state_dict(), "distilgpt2_qat_int8.pth")
qat_size = os.path.getsize("distilgpt2_qat_int8.pth") / 1e6
print(f"QAT INT8 Model Size: {qat_size:.2f} MB") # Should be similar to PTQ size
# 2. Accuracy (Perplexity)
results_qat = compute_perplexity(model_qat, test_encodings, device="cpu") # INT8 eager model stays on CPU
print(f"FP32 Perplexity: {results_fp32:.4f}") # From PTQ eval
print(f"PTQ INT8 Perplexity: {results_ptq:.4f}") # From PTQ eval
print(f"QAT INT8 Perplexity: {results_qat:.4f}")
# 3. Inference Speed (Latency)
model_qat.cpu() # Benchmark on CPU, like the PTQ model
with torch.no_grad():
    # Warm-up
    for _ in range(10):
        _ = model_qat(dummy_input)
    start_time = time.time()
    for _ in range(repetitions):
        _ = model_qat(dummy_input)
qat_latency = (time.time() - start_time) / repetitions * 1000
print(f"FP32 Average Latency: {fp32_latency:.2f} ms") # From PTQ eval
print(f"PTQ INT8 Average Latency: {ptq_latency:.2f} ms") # From PTQ eval
print(f"QAT INT8 Average Latency: {qat_latency:.2f} ms")
print(f"QAT Speedup vs FP32: {fp32_latency / qat_latency:.2f}x")
You should typically observe that QAT achieves better accuracy than PTQ, potentially closing the gap with the original FP32 model significantly, while offering similar size reduction and speedup benefits.
Let's visualize the potential trade-offs based on the hypothetical results obtained above.
Comparison of hypothetical model sizes and perplexity scores for DistilGPT2 (FP32) versus its PTQ and QAT INT8 quantized versions. Lower perplexity indicates better language modeling performance. PTQ reduces size significantly but slightly increases perplexity. QAT maintains the size reduction while recovering some of the performance lost by PTQ.
This practical demonstrated the core workflows for PTQ and QAT using INT8 precision. In practice, you would extend this, for example by moving to even lower precisions with libraries such as bitsandbytes or with specialized hardware backends (see the sketch below).

This hands-on exercise forms a foundation for applying advanced quantization. Mastering PTQ and QAT allows you to significantly reduce the resource requirements of LLMs, making deployment more feasible across a wider range of applications and hardware platforms. Remember that the best approach (PTQ vs. QAT) and configuration often depend on the specific model, task, acceptable accuracy trade-off, and available compute resources for fine-tuning.
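For instance, a hedged sketch of loading a model with 4-bit weights via bitsandbytes through transformers (argument names such as bnb_4bit_quant_type follow recent transformers releases; check your installed version):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4-bit on load
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # higher-precision compute for matmuls
)
# Requires a CUDA GPU and the bitsandbytes package
model_4bit = AutoModelForCausalLM.from_pretrained("distilgpt2", quantization_config=bnb_config, device_map="auto")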