Knowledge distillation offers a distinct approach to model compression. Instead of modifying the parameters of the large model directly through quantization or pruning, distillation focuses on transferring the knowledge from a large, pre-trained model (the "teacher") to a smaller, more efficient model (the "student"). The objective is to train a student model that achieves significantly better performance than if it were trained from scratch with the same architecture, by learning from the richer output signals provided by the teacher. This makes the student model more suitable for deployment scenarios with limited computational resources or strict latency requirements.
The core idea hinges on the observation that a well-trained large model captures complex patterns and nuances from the data, which are reflected not only in its final predictions but also in its internal representations and output probability distributions. The student model learns by minimizing a loss function that encourages it to mimic these behaviors of the teacher, in addition to learning the original task objective using ground-truth labels.
Several strategies exist for transferring knowledge from the teacher to the student:
Matching Output Logits (Soft Targets): This is the most common form of knowledge distillation. Instead of training the student solely on the hard ground-truth labels (e.g., one-hot encoded vectors), it's also trained to match the probability distribution produced by the teacher model over the possible output classes or tokens. To provide a richer learning signal, the outputs of both the teacher and student models are often "softened" using a temperature parameter (T) in the softmax function:
$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here, $z_i$ represents the logit for class $i$. A higher temperature ($T > 1$) produces a softer probability distribution over classes, revealing more information about the teacher's internal "confidence" and the similarity structure between classes. A temperature of $T = 1$ corresponds to the standard softmax. The distillation loss is typically the Kullback-Leibler (KL) divergence or Mean Squared Error (MSE) between the softened probability distributions of the teacher ($p^T$) and the student ($p^S$):

$$L_{KD} = \mathrm{KL}(p^T \,\|\, p^S) = \sum_i p_i^T \log \frac{p_i^T}{p_i^S}$$

or

$$L_{KD} = \mathrm{MSE}(z^T, z^S)$$

where $z^T$ and $z^S$ are the logits (pre-softmax outputs) of the teacher and student, respectively. Using logits directly with MSE can sometimes be simpler and equally effective.
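To make the effect of temperature concrete, here is a small illustration (the logit values are made up purely for demonstration) of how raising $T$ spreads probability mass across classes:
import torch
import torch.nn.functional as F

# Illustrative logits for a 3-class problem (values chosen only for demonstration)
logits = torch.tensor([4.0, 1.0, 0.5])

# Standard softmax (T=1) concentrates almost all mass on the top class
print(F.softmax(logits / 1.0, dim=-1))  # approx. [0.93, 0.05, 0.03]

# A higher temperature (T=4) softens the distribution, exposing the teacher's
# relative preferences among the non-top classes
print(F.softmax(logits / 4.0, dim=-1))  # approx. [0.53, 0.25, 0.22]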
Matching Intermediate Features: Knowledge can also be transferred by encouraging the student model to replicate the activations or hidden states from intermediate layers of the teacher model. This forces the student to learn similar internal representations. A loss function, often MSE, is calculated between the feature maps of corresponding layers in the teacher and student.
$$L_{Feature} = \mathrm{MSE}(f^T(x), f^S(x))$$

Here, $f^T(x)$ and $f^S(x)$ represent the feature activations from selected layers of the teacher and student for input $x$. A challenge here is aligning layers, especially if the architectures differ significantly. Often, linear transformations are learned to map the student's features to the dimensionality of the teacher's features before calculating the loss.
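As a minimal sketch of this idea (the hidden sizes and random tensors below are placeholders; in practice the features would come from specific layers of the two models), the learned projection and feature loss could be implemented as:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative hidden sizes; actual values depend on the chosen architectures
teacher_dim, student_dim = 1024, 512

# Learned linear map from the student's feature space to the teacher's,
# optimized jointly with the student's parameters
projection = nn.Linear(student_dim, teacher_dim)

# Stand-ins for hidden states from corresponding layers,
# shaped (batch, sequence_length, hidden_dim)
teacher_features = torch.randn(8, 128, teacher_dim)                      # frozen teacher output
student_features = torch.randn(8, 128, student_dim, requires_grad=True)  # student output

# Project the student's features and match the teacher's activations with MSE
loss_feature = F.mse_loss(projection(student_features), teacher_features)
loss_feature.backward()  # gradients flow into the student features and the projection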
Matching Attention Mechanisms: For Transformer-based models, the attention patterns learned by the teacher contain valuable relational information between tokens. Distillation can involve training the student to produce similar attention maps as the teacher. The loss is computed based on the difference between the attention weight matrices of corresponding layers or heads.
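A minimal sketch of attention matching, assuming both models expose per-layer attention weights of shape (batch, heads, seq_len, seq_len) with the same number of heads (random tensors stand in for the real attention maps here), might look like:
import torch
import torch.nn.functional as F

# Stand-ins for attention weights from one corresponding layer of each model;
# each row over the last dimension is a probability distribution over keys
teacher_attn = torch.softmax(torch.randn(8, 12, 128, 128), dim=-1)
student_attn = torch.softmax(torch.randn(8, 12, 128, 128), dim=-1)

# Option 1: mean squared error between the attention matrices
loss_attn_mse = F.mse_loss(student_attn, teacher_attn)

# Option 2: KL divergence, treating each attention row as a distribution;
# the clamp avoids log(0) for numerical stability
loss_attn_kl = F.kl_div(student_attn.clamp_min(1e-12).log(),
                        teacher_attn, reduction='batchmean')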
In the standard training setup for knowledge distillation, the student model is trained using a combined loss function: a weighted sum of the standard task loss ($L_{Task}$, e.g., cross-entropy loss against the ground-truth labels) and the distillation loss ($L_{KD}$):
$$L_{Total} = \alpha L_{Task} + (1 - \alpha) L_{KD}$$

The hyperparameter $\alpha$ (typically between 0 and 1) balances the importance of matching the ground truth versus mimicking the teacher. The temperature $T$ used for softening the logits is another important hyperparameter.
Here's a PyTorch snippet illustrating the loss calculation for logit matching using KL divergence:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Assume teacher_model and student_model are defined
# Assume inputs and labels are available from the dataloader
teacher_model.eval() # Teacher model is in evaluation mode and frozen
student_model.train() # Student model is in training mode
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
# Hyperparameters
temperature = 4.0
alpha = 0.3 # Weight for the standard task loss
# Standard Cross-Entropy loss
criterion_task = nn.CrossEntropyLoss()
# KL Divergence loss for distillation
criterion_kd = nn.KLDivLoss(reduction='batchmean') # Use batchmean reduction
# --- Training Loop ---
# for inputs, labels in dataloader:
optimizer.zero_grad()
# Get outputs from student
student_logits = student_model(inputs)
# Get outputs from teacher (no gradients needed)
with torch.no_grad():
    teacher_logits = teacher_model(inputs)
# Calculate standard task loss (using student logits and ground-truth labels)
# Note: CrossEntropyLoss expects raw logits
loss_task = criterion_task(student_logits, labels)
# Calculate distillation loss (using softened logits)
# Temperature-scaled log_softmax for the student (KLDivLoss expects log-probabilities) and softmax for the teacher
student_log_probs_soft = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs_soft = F.softmax(teacher_logits / temperature, dim=-1)
# KLDivLoss expects log-probabilities for the student
# and probabilities for the teacher
# Multiply by T*T as per original Hinton distillation paper scaling
loss_kd = criterion_kd(student_log_probs_soft,
                       teacher_probs_soft) * (temperature ** 2)
# Combine losses
total_loss = alpha * loss_task + (1 - alpha) * loss_kd
total_loss.backward()
optimizer.step()
# --- End Training Loop ---
print(
    f"Task Loss: {loss_task.item():.4f}, "
    f"KD Loss: {loss_kd.item():.4f}, "
    f"Total Loss: {total_loss.item():.4f}"
)
This code snippet shows the core logic for combining the standard cross-entropy loss with the KL divergence-based distillation loss using softened outputs. The temperature and alpha parameters control the distillation process.
The student model architecture is typically chosen to be significantly smaller and faster than the teacher, for example by reducing the number of layers, shrinking the hidden dimensions, or using fewer attention heads.
The student does not need to be a strict subset of the teacher's architecture. Different architectural choices can be explored, as long as the student has sufficient capacity to learn the distilled knowledge effectively for the target task.
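For instance, with the Hugging Face transformers library (assuming it is available; the specific sizes below are illustrative, not a recommendation), a compact student could be configured by shrinking several dimensions relative to a BERT-base-sized teacher:
from transformers import BertConfig, BertForSequenceClassification

# Hypothetical student configuration: fewer layers, a smaller hidden size,
# and fewer attention heads than a 12-layer, 768-dimensional teacher
student_config = BertConfig(
    num_hidden_layers=4,
    hidden_size=384,
    num_attention_heads=6,
    intermediate_size=1536,
)
student_model = BertForSequenceClassification(student_config)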
Advantages: Distillation can produce compact, fast students that recover much of the teacher's predictive quality, typically more than training the same small architecture from scratch, because the teacher's soft targets carry richer information than hard labels alone. It also leaves the choice of student architecture open, unlike quantization or pruning, which operate on the original network's structure.
Disadvantages and Considerations: Training requires access to a capable teacher and the compute to run it on every training example, which adds cost to the training pipeline. The student's quality depends heavily on the teacher's, and the additional hyperparameters, such as the temperature and the weighting between the task and distillation losses, need careful tuning.
Knowledge distillation provides a powerful tool for creating smaller, more efficient language models that retain much of the predictive power of their larger counterparts, making it a valuable technique in the LLM compression toolkit alongside quantization and pruning.