Building upon the concepts of quantization discussed earlier, this section provides practical guidance on implementing two primary techniques for large language models: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT). We move from theory to application, focusing on the typical workflow, code implementation patterns using common libraries, and evaluation of the resulting models. The goal is to equip you with the hands-on skills needed to apply these optimization methods effectively.

We assume you are comfortable with Python and deep learning frameworks like PyTorch, along with the Hugging Face Transformers library. For this practical exercise, we outline the steps using PyTorch's built-in quantization tools; the concepts carry over to libraries like Hugging Face Optimum or Intel Neural Compressor, which often provide higher-level APIs.

## Environment Setup

Before starting, ensure your environment includes the necessary libraries:

- `torch` and `torchvision` (PyTorch core and utilities)
- `transformers` (for loading pre-trained LLMs and tokenizers)
- `datasets` (for handling calibration and evaluation data)
- Optionally, `optimum` or other specialized quantization libraries for streamlined workflows

We recommend using a relatively small pre-trained transformer model (such as `distilgpt2`, `bert-base-uncased`, or a manageable slice of a larger model) for these examples to keep computation times reasonable.

```bash
# Example environment setup
# pip install torch torchvision torchaudio
# pip install transformers datasets evaluate accelerate bitsandbytes
# Add others as needed
```

## Implementing Post-Training Quantization (PTQ)

PTQ is attractive because it does not require retraining the model. It involves calibrating the model on a small, representative dataset to determine the quantization parameters (scale and zero-point) for weights and activations.

### 1. Model and Tokenizer Loading

First, load your target pre-trained model and its corresponding tokenizer, and set the model to evaluation mode.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "distilgpt2"  # Example model
model_fp32 = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers ship without a pad token
model_fp32.eval()  # Set to evaluation mode
```

### 2. Calibration Dataset Preparation

Select a dataset that reflects the kind of data the model will encounter during inference; a few hundred samples are often sufficient. Preprocess this data using the model's tokenizer.

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load a small calibration dataset (e.g., 100 examples from wikitext)
calibration_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]")

def preprocess_function(examples):
    # Adjust max_length as needed for your model/task
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

tokenized_calibration_data = calibration_data.map(preprocess_function, batched=True)
tokenized_calibration_data.set_format(type="torch", columns=["input_ids", "attention_mask"])

# Use a DataLoader for batching
calibration_dataloader = DataLoader(tokenized_calibration_data, batch_size=8)
```

### 3. PTQ Configuration and Calibration

Configure the quantization settings. PyTorch's `torch.quantization` module offers defaults (such as `get_default_qconfig`) or allows fine-grained control (e.g., symmetric per-channel quantization for weights, asymmetric per-tensor quantization for activations). You then prepare the model by inserting observers with `torch.quantization.prepare`; these observers collect statistics about activation ranges during calibration.
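If the defaults don't fit, a custom qconfig can combine different observers for weights and activations. The following is a minimal sketch (the observer choices are illustrative, not a recommendation) pairing per-tensor asymmetric activations with per-channel symmetric weights:

```python
import torch
from torch.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver

# Illustrative custom qconfig: asymmetric (affine) per-tensor activations,
# symmetric per-channel weights. Swapping in HistogramObserver for activations
# trades longer calibration for potentially tighter ranges.
custom_qconfig = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)
```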
```python
import copy

# Configure PTQ with a backend: 'fbgemm' (or 'x86') targets servers, 'qnnpack' targets ARM/mobile
qconfig = torch.quantization.get_default_qconfig('fbgemm')  # For x86

# Prepare model for PTQ: inserts observers into supported modules
# Note: the eager-mode prepare/convert API is evolving (newer releases expose it
# under torch.ao.quantization); check the PyTorch docs for your version.
# Some libraries use different functions or context managers.
model_to_quantize = copy.deepcopy(model_fp32)  # Work on a copy
model_to_quantize.qconfig = qconfig
torch.quantization.prepare(model_to_quantize, inplace=True)  # Eager-mode API

# Calibrate the model by running data through it
print("Running calibration...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_to_quantize.to(device)

with torch.no_grad():
    for batch in calibration_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        try:
            model_to_quantize(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            # Depending on the model/library, specific input formats may be required
            print(f"Warning: Batch failed during calibration: {e}")

print("Calibration finished.")
```

Note: Handling model inputs and outputs correctly during calibration is important; some model architectures or quantization backends have specific requirements. The eager-mode `prepare`/`convert` API in PyTorch is also undergoing changes, so consult the documentation for your specific version. Libraries like Hugging Face Optimum often simplify this process significantly.

### 4. Conversion to Quantized Model

After calibration, convert the model to its quantized equivalent using `torch.quantization.convert`. This replaces the observed modules with their quantized counterparts.

```python
model_to_quantize.cpu()  # Conversion is typically done on CPU
model_ptq = torch.quantization.convert(model_to_quantize, inplace=False)
print("PTQ conversion complete.")
```
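Before moving on to evaluation, note that there is an even simpler PTQ variant worth trying as a baseline: dynamic quantization, which stores weights in INT8 and computes activation scales on the fly at inference time, so no calibration pass is needed. A minimal sketch is below; it targets `torch.nn.Linear` modules, so coverage depends on the architecture (GPT-2 variants in Transformers use a custom `Conv1D` layer, for example).

```python
import torch

# Dynamic PTQ sketch: INT8 weights for Linear layers, activation scales
# computed per batch at inference time. CPU execution, no calibration data.
model_dynamic = torch.quantization.quantize_dynamic(
    model_fp32,            # FP32 model loaded in step 1
    {torch.nn.Linear},     # module types to quantize
    dtype=torch.qint8,
)
```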
### 5. Evaluation

Evaluate the PTQ model against the original FP32 model. Important metrics include:

- **Model Size:** Compare the file size of the saved state dictionaries.
- **Accuracy/Fidelity:** Measure performance on a relevant downstream task (e.g., perplexity for language modeling, accuracy for classification).
- **Inference Speed:** Benchmark latency and throughput on target hardware.

```python
import math
import os
import time

# 1. Model Size
torch.save(model_fp32.state_dict(), "distilgpt2_fp32.pth")
torch.save(model_ptq.state_dict(), "distilgpt2_ptq_int8.pth")

fp32_size = os.path.getsize("distilgpt2_fp32.pth") / 1e6
ptq_size = os.path.getsize("distilgpt2_ptq_int8.pth") / 1e6
print(f"FP32 Model Size: {fp32_size:.2f} MB")
print(f"PTQ INT8 Model Size: {ptq_size:.2f} MB")
print(f"Size Reduction: {(1 - ptq_size / fp32_size) * 100:.2f}%")

# 2. Accuracy (Example: Perplexity on a test set)
# The Hugging Face `evaluate` perplexity metric reloads a model from the Hub by
# model_id, so it cannot score our in-memory quantized model. Instead, compute
# perplexity directly from the language-modeling loss.
test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:50]")  # Small test set
test_encodings = tokenizer("\n\n".join(test_data["text"]), return_tensors="pt")

# Eager-mode INT8 models run on CPU backends (fbgemm/qnnpack), so evaluate both
# models on CPU for a like-for-like comparison.
eval_device = torch.device("cpu")
model_fp32.to(eval_device)
model_ptq.to(eval_device)

def compute_perplexity(model, encodings, device, max_length=128, stride=128):
    """Sliding-window perplexity computed from the model's own LM loss."""
    input_ids = encodings.input_ids
    total_nll, total_tokens = 0.0, 0
    with torch.no_grad():
        for i in range(0, input_ids.size(1), stride):
            chunk = input_ids[:, i:i + max_length].to(device)
            if chunk.size(1) < 2:
                break
            loss = model(chunk, labels=chunk).loss  # labels are shifted internally
            n_tokens = chunk.size(1) - 1
            total_nll += loss.item() * n_tokens
            total_tokens += n_tokens
    return math.exp(total_nll / total_tokens)

ppl_fp32 = compute_perplexity(model_fp32, test_encodings, eval_device)
ppl_ptq = compute_perplexity(model_ptq, test_encodings, eval_device)
print(f"FP32 Perplexity: {ppl_fp32:.4f}")
print(f"PTQ INT8 Perplexity: {ppl_ptq:.4f}")

# 3. Inference Speed (Simple Latency Example)
dummy_input = tokenizer("This is a sample text for benchmarking.",
                        return_tensors="pt").input_ids.to(eval_device)
repetitions = 100

with torch.no_grad():
    # Warm-up runs
    for _ in range(10):
        _ = model_fp32(dummy_input)
        _ = model_ptq(dummy_input)

    # Timed runs
    start_time = time.time()
    for _ in range(repetitions):
        _ = model_fp32(dummy_input)
    fp32_latency = (time.time() - start_time) / repetitions * 1000  # milliseconds

    start_time = time.time()
    for _ in range(repetitions):
        _ = model_ptq(dummy_input)
    ptq_latency = (time.time() - start_time) / repetitions * 1000  # milliseconds

print(f"FP32 Average Latency: {fp32_latency:.2f} ms")
print(f"PTQ INT8 Average Latency: {ptq_latency:.2f} ms")
print(f"Speedup: {fp32_latency / ptq_latency:.2f}x")
```

PTQ typically yields significant size reduction and speedup, especially on hardware with native INT8 support, but often comes with a small drop in accuracy.

## Implementing Quantization-Aware Training (QAT)

QAT simulates quantization effects during a fine-tuning phase, allowing the model to adapt its weights to the reduced precision. This usually recovers some or all of the accuracy lost with PTQ, at the cost of additional training compute.

### 1. Model Preparation for QAT

Load the pre-trained model again, but start in training mode rather than evaluation mode. Use a QAT-specific configuration (`get_default_qat_qconfig`) and prepare the model with `torch.quantization.prepare_qat`. This inserts "fake quantization" modules that mimic the effects of quantization during both the forward and backward passes.

```python
# Load the base FP32 model again
model_for_qat = AutoModelForCausalLM.from_pretrained(model_id)

# Configure QAT
# Note: backend choice ('fbgemm', 'qnnpack') affects supported ops and performance
qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Prepare model for QAT: inserts fake-quantization modules
model_for_qat.train()  # Set to training mode
model_for_qat.qconfig = qat_qconfig

# Layer fusion (e.g., Conv-BN-ReLU, Linear-ReLU) can improve QAT accuracy and performance.
# fuse_list = [["ffn.dense", "ffn.act"]]  # Hypothetical module names; inspect your model
# model_for_qat = torch.quantization.fuse_modules(model_for_qat, fuse_list)

torch.quantization.prepare_qat(model_for_qat, inplace=True)
print("Model prepared for QAT.")
```

Note: For transformer models, identifying fusible layers (such as a Linear followed by an activation) can be beneficial but requires careful inspection of the model architecture. Some libraries automate parts of this.
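To make the fusion call pattern concrete, here is a toy sketch on a standalone Linear + ReLU pair. The module names `"0"` and `"1"` are the ones `nn.Sequential` assigns; for a real transformer you would substitute the fully qualified names found via `print(model)`.

```python
import torch.nn as nn
from torch.quantization import fuse_modules

# Toy fusion example: merge a Linear + ReLU pair into a single intrinsic module.
toy = nn.Sequential(nn.Linear(768, 768), nn.ReLU())
toy.eval()  # eager-mode fusion expects eval mode in recent PyTorch releases
fused = fuse_modules(toy, [["0", "1"]])
print(fused)  # module "0" becomes LinearReLU, module "1" becomes Identity
```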
### 2. Fine-tuning Phase

Fine-tune the prepared model on a relevant dataset (e.g., the downstream task dataset or a general text corpus) for a small number of epochs. Use a standard training loop, but be aware that QAT often requires lower learning rates and more careful hyperparameter tuning than standard FP32 fine-tuning. Gradients flow through the fake-quantization operations, allowing the model to adjust.

```python
from torch.optim import AdamW
from transformers import get_scheduler

# Prepare training dataset (e.g., fine-tuning on a specific task or continuing LM)
# Use a different split or dataset than calibration for realism
train_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[100:1100]")  # Example: 1000 samples
tokenized_train_data = train_dataset.map(preprocess_function, batched=True)
tokenized_train_data.set_format(type="torch", columns=["input_ids", "attention_mask"])  # labels are built in the loop below
train_dataloader = DataLoader(tokenized_train_data, batch_size=4, shuffle=True)  # Smaller batch sizes often help

# Optimizer and scheduler
optimizer = AdamW(model_for_qat.parameters(), lr=1e-5)  # Typically a lower LR for QAT
num_epochs = 1  # Adjust as needed (e.g., 1-3)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Fine-tuning loop
print("Starting QAT fine-tuning...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_for_qat.to(device)
model_for_qat.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = input_ids.clone()  # Causal LM: labels are the inputs (shifted internally)
        outputs = model_for_qat(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()  # Gradients flow through the fake-quant nodes
        optimizer.step()
        lr_scheduler.step()
    print(f"QAT Epoch {epoch + 1} complete. Last loss: {loss.item():.4f}")

print("QAT fine-tuning finished.")
```
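One common QAT refinement, not shown in the loop above, is to freeze the observer statistics partway through training so the scales and zero-points stop drifting while the weights keep adapting to the now-fixed quantization grid. A minimal sketch follows; the helper name and the point at which you call it are illustrative choices, not part of the recipe above.

```python
from torch.quantization import disable_observer

def freeze_quant_observers(model):
    """Stop observers from updating scale/zero-point estimates.

    Typically applied after the first epoch or two of QAT so the remaining
    training adapts the weights against a fixed quantization grid.
    """
    model.apply(disable_observer)

# Example usage inside the training loop above, at an epoch of your choosing:
# if epoch >= 1:  # from the second epoch onward in a multi-epoch run
#     freeze_quant_observers(model_for_qat)
```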
### 3. Conversion to Quantized Model

After fine-tuning, convert the QAT model to a truly quantized integer model, just as in the PTQ flow. The model must be in evaluation mode for conversion.

```python
model_for_qat.cpu()   # Conversion is typically done on CPU
model_for_qat.eval()  # Set to eval mode before final conversion
model_qat = torch.quantization.convert(model_for_qat, inplace=False)
print("QAT conversion complete.")
```

### 4. Evaluation

Evaluate the QAT model using the same metrics as the PTQ model (size, accuracy, speed), and compare the results against both the FP32 baseline and the PTQ model.

```python
# 1. Model Size
torch.save(model_qat.state_dict(), "distilgpt2_qat_int8.pth")
qat_size = os.path.getsize("distilgpt2_qat_int8.pth") / 1e6
print(f"QAT INT8 Model Size: {qat_size:.2f} MB")  # Should be similar to the PTQ size

# 2. Accuracy (Perplexity), reusing compute_perplexity from the PTQ evaluation
model_qat.to(eval_device)
ppl_qat = compute_perplexity(model_qat, test_encodings, eval_device)
print(f"FP32 Perplexity: {ppl_fp32:.4f}")      # From PTQ eval
print(f"PTQ INT8 Perplexity: {ppl_ptq:.4f}")   # From PTQ eval
print(f"QAT INT8 Perplexity: {ppl_qat:.4f}")

# 3. Inference Speed (Latency)
with torch.no_grad():
    # Warm-up
    for _ in range(10):
        _ = model_qat(dummy_input)

    start_time = time.time()
    for _ in range(repetitions):
        _ = model_qat(dummy_input)
    qat_latency = (time.time() - start_time) / repetitions * 1000  # milliseconds

print(f"FP32 Average Latency: {fp32_latency:.2f} ms")      # From PTQ eval
print(f"PTQ INT8 Average Latency: {ptq_latency:.2f} ms")   # From PTQ eval
print(f"QAT INT8 Average Latency: {qat_latency:.2f} ms")
print(f"QAT Speedup vs FP32: {fp32_latency / qat_latency:.2f}x")
```

You should typically observe that QAT achieves better accuracy than PTQ, often closing much of the gap to the original FP32 model, while offering similar size reduction and speedup benefits.

## Comparing PTQ and QAT Results

Let's summarize the potential trade-offs, using the illustrative DistilGPT2 numbers from the chart "Quantization Trade-offs: Perplexity vs. Model Size":

| Model    | Size (MB) | Perplexity (lower is better) |
|----------|-----------|------------------------------|
| FP32     | 315       | 18.5                         |
| PTQ INT8 | 85        | 20.5                         |
| QAT INT8 | 85        | 19.0                         |

Comparison of model sizes and perplexity scores for DistilGPT2 (FP32) versus its PTQ and QAT INT8 quantized versions. Lower perplexity indicates better language modeling performance. PTQ reduces size significantly but slightly increases perplexity; QAT maintains the size reduction while recovering some of the performance lost by PTQ.

## Advanced Next Steps

This practical demonstrated the core workflows for PTQ and QAT at INT8 precision. In practice, you would extend this by:

- **Exploring Different Precisions:** Applying the same principles to lower precisions like INT4, NF4, or FP4, potentially using libraries like bitsandbytes or specialized hardware backends (a minimal 4-bit loading sketch appears at the end of this section).
- **Mixed Precision:** Implementing strategies where different parts of the model use different precisions based on sensitivity analysis.
- **Hardware-Specific Kernels:** Leveraging optimized kernels (e.g., via TensorRT, ONNX Runtime, vLLM) that provide maximum speedup for quantized operations on specific GPUs or accelerators. This is covered in more detail in Chapter 6.
- **Integration with Other Techniques:** Combining quantization with pruning or PEFT methods like LoRA (as in QLoRA) for further optimization, discussed in later chapters.
- **Evaluation:** Performing comprehensive evaluations across multiple datasets and tasks to fully understand the impact of quantization on model capabilities, including fairness and related checks (Chapter 7).

This hands-on exercise forms a foundation for applying advanced quantization. Mastering PTQ and QAT allows you to significantly reduce the resource requirements of LLMs, making deployment feasible across a wider range of applications and hardware platforms. Remember that the best approach (PTQ vs. QAT) and configuration depend on the specific model, task, acceptable accuracy trade-off, and the compute available for fine-tuning.
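As a concrete pointer for the different-precisions item above, the sketch below loads the same model with 4-bit NF4 weights through bitsandbytes and Transformers. It assumes a CUDA GPU plus the `bitsandbytes` and `accelerate` packages from the environment setup; treat it as a starting point rather than a tuned recipe.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 weight loading via bitsandbytes (requires a CUDA GPU).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmul compute
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "distilgpt2",
    quantization_config=bnb_config,
    device_map="auto",  # weight placement handled by accelerate
)
```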