The core challenge in knowledge distillation (KD) is defining precisely what knowledge to transfer from the teacher to the student and how to measure the success of this transfer. The mechanism for this is the distillation objective or loss function. It quantifies the discrepancy between the teacher's behavior and the student's behavior, guiding the student's training process. While the original KD formulation focused on matching output distributions, several sophisticated objectives have been developed to capture richer information embedded within the teacher model.
The most fundamental KD objective, proposed by Hinton et al. (2015), involves training the student to mimic the teacher's output probability distribution over classes or tokens. Simply matching the final predictions (hard labels) isn't sufficient, as the teacher's distribution often contains subtle information about class relationships or token likelihoods – knowledge that is lost when converting to a single, hard prediction.
To extract this richer information, a temperature scaling parameter $T$ is introduced into the softmax function applied to the logits (the raw, unnormalized outputs before the final activation).
$$\sigma(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here, $z_i$ is the logit for class or token $i$. A higher temperature ($T > 1$) softens the probability distribution, pushing probabilities closer together and revealing more about the relative similarities the teacher assigns to different outputs. A temperature of $T = 1$ recovers the standard softmax.
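To make the softening effect concrete, here is a small sketch in PyTorch with made-up logits (the values are purely illustrative):

```python
import torch

# Hypothetical teacher logits for three classes.
logits = torch.tensor([3.0, 1.0, 0.2])

for T in (1.0, 4.0):
    probs = torch.softmax(logits / T, dim=-1)
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")

# T=1.0 yields a peaked distribution (~0.84, 0.11, 0.05), while T=4.0
# softens it (~0.48, 0.29, 0.24), exposing how the teacher ranks the
# non-top classes relative to each other.
```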
The distillation loss, $L_{KD}$, is typically the Kullback-Leibler (KL) divergence between the teacher's softened predictions $p_T = \sigma(z_T, T)$ and the student's softened predictions $p_S = \sigma(z_S, T)$:
$$L_{KD} = T^2 \cdot D_{KL}\big(p_T \,\|\, p_S\big) = T^2 \sum_i p_T(i) \log \frac{p_T(i)}{p_S(i)}$$

The $T^2$ scaling factor is important. Because the gradients produced by the softened targets are scaled by $1/T^2$ relative to the gradients from hard targets, multiplying the KD loss by $T^2$ keeps its relative contribution to training roughly constant even if the temperature $T$ is changed.
This objective encourages the student not just to predict the correct output but to understand why the teacher predicts it, learning the nuanced relationships between outputs. In practice, this $L_{KD}$ is often combined with a standard supervised loss (e.g., cross-entropy $L_{CE}$) on the true hard labels, using a weighting factor $\alpha$:
$$L_{Total} = \alpha \, L_{CE}\big(y_{true}, \sigma(z_S, T{=}1)\big) + (1 - \alpha) \, L_{KD}\big(\sigma(z_S, T), \sigma(z_T, T)\big)$$

This ensures the student still learns to match the ground truth while benefiting from the teacher's soft targets. Choosing the optimal temperature $T$ and weight $\alpha$ often requires empirical tuning.
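A minimal sketch of this combined objective in PyTorch follows; the function name, the default $T$ and $\alpha$, and the random tensors are illustrative assumptions rather than a canonical implementation:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=2.0, alpha=0.5):
    """Weighted sum of hard-label cross-entropy and temperature-scaled KD loss."""
    # Standard supervised loss against the ground-truth labels (implicitly T = 1).
    ce_loss = F.cross_entropy(student_logits, targets)

    # Softened distributions; kl_div expects log-probabilities for the student.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)

    # KL(p_T || p_S), averaged over the batch and scaled by T^2.
    kd_loss = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)

    return alpha * ce_loss + (1 - alpha) * kd_loss

# Example with random tensors: batch of 4 examples, 10 classes.
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = distillation_loss(student_logits, teacher_logits, targets)
```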
While matching output distributions is effective, knowledge isn't solely contained in the final layer. Intermediate layers of deep networks like LLMs learn hierarchical representations that capture syntactic, semantic, and contextual information. Distilling this intermediate knowledge can provide stronger guidance to the student.
The objective here is to minimize the difference between the hidden states or activations of selected intermediate layers in the teacher ($h_T^l$) and student ($h_S^l$) models. A common choice is a mean squared error (MSE) loss over the matched layers (cosine-similarity variants are used in the same way):

$$L_{Intermediate} = \sum_{l \in L_{match}} \big\| f_S(h_S^l) - f_T(h_T^l) \big\|_2^2$$

Here, $L_{match}$ is the set of layer indices chosen for matching. The functions $f_S$ and $f_T$ represent optional transformation layers (e.g., linear projections) used to align the dimensions if the student and teacher layers have different hidden sizes.
Key considerations for intermediate matching include deciding which teacher and student layers to pair (particularly when the student has fewer layers than the teacher) and how to reconcile mismatched hidden sizes, typically via the projection layers $f_S$ and $f_T$, as in the sketch below.
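As an illustration, here is a sketch of an MSE-based intermediate matching loss with a linear projection bridging differing hidden sizes; the dimensions and layer pairing below are assumed values, not taken from any particular model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HiddenStateMatcher(nn.Module):
    """MSE between projected student hidden states and frozen teacher hidden states."""

    def __init__(self, student_dim=512, teacher_dim=1024):
        super().__init__()
        # f_S: learnable projection from the student's hidden size to the teacher's.
        self.proj = nn.Linear(student_dim, teacher_dim)

    def forward(self, student_hiddens, teacher_hiddens, layer_pairs):
        # layer_pairs lists (student_layer_index, teacher_layer_index) tuples.
        loss = 0.0
        for s_idx, t_idx in layer_pairs:
            h_s = self.proj(student_hiddens[s_idx])   # (batch, seq, teacher_dim)
            h_t = teacher_hiddens[t_idx].detach()     # teacher provides fixed targets
            loss = loss + F.mse_loss(h_s, h_t)
        return loss / len(layer_pairs)

# Hypothetical usage: match student layers 2 and 4 to teacher layers 6 and 12.
matcher = HiddenStateMatcher()
student_hiddens = [torch.randn(4, 16, 512) for _ in range(5)]
teacher_hiddens = [torch.randn(4, 16, 1024) for _ in range(13)]
loss = matcher(student_hiddens, teacher_hiddens, layer_pairs=[(2, 6), (4, 12)])
```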
For Transformer-based LLMs, the self-attention mechanism is a defining component. Attention maps, which represent the weighted relationships between tokens at different positions, encode significant structural and contextual information. Transferring this attention knowledge can help the student learn similar relational patterns.
The Attention Transfer (AT) objective minimizes the difference between attention maps ($A_T^l$, $A_S^l$) from corresponding layers:
$$L_{Attention} = \sum_{l \in L_{match}} \frac{1}{N_h} \sum_{h=1}^{N_h} \big\| A_{S,h}^l - A_{T,h}^l \big\|_F^2$$

Where $N_h$ is the number of attention heads, and $\|\cdot\|_F^2$ denotes the squared Frobenius norm (sum of squared elements) of the difference between the attention matrices for head $h$ in layer $l$.
Challenges include mismatches in the number of layers and attention heads between teacher and student (the formulation above assumes equal $N_h$), choosing which layers to pair, and the memory cost of materializing full sequence-length by sequence-length attention maps during training. A sketch of the loss term follows.
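A possible implementation of the attention transfer term, assuming the teacher and student use the same number of heads and the same sequence length, with an assumed layer pairing:

```python
import torch

def attention_transfer_loss(student_attn, teacher_attn, layer_pairs):
    """Squared Frobenius norm between matched attention maps, averaged over heads.

    Each attention tensor is expected to have shape (batch, num_heads, seq_len, seq_len).
    """
    loss = 0.0
    for s_idx, t_idx in layer_pairs:
        diff = student_attn[s_idx] - teacher_attn[t_idx].detach()
        # Sum of squared elements per head, then average over heads and batch.
        loss = loss + diff.pow(2).sum(dim=(-2, -1)).mean()
    return loss / len(layer_pairs)

# Hypothetical usage: batch of 4, 8 heads, sequence length 16.
student_attn = [torch.softmax(torch.randn(4, 8, 16, 16), dim=-1) for _ in range(4)]
teacher_attn = [torch.softmax(torch.randn(4, 8, 16, 16), dim=-1) for _ in range(12)]
loss = attention_transfer_loss(student_attn, teacher_attn, layer_pairs=[(1, 3), (3, 11)])
```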
Rather than directly matching representations element-wise (like MSE) or angle-wise (like cosine similarity), contrastive objectives focus on learning similarities and dissimilarities. Contrastive Representation Distillation (CRD) aims to teach the student to produce representations that are close to the teacher's representation for the same input (positive pair) but far from the teacher's representations for different inputs (negative pairs).
Conceptually, the loss encourages $\mathrm{sim}(z_S, z_T)$ to be high when both representations come from the same input, while $\mathrm{sim}(z_S, z_{T,neg})$ should be low, where $z_{T,neg}$ are teacher representations for other inputs in the batch or a memory bank. A typical loss function like InfoNCE can be adapted:
$$L_{Contrastive} \propto -\log \frac{\exp\big(\mathrm{sim}(f_S(h_S), f_T(h_T)) / \tau\big)}{\sum_{h_{T,neg}} \exp\big(\mathrm{sim}(f_S(h_S), f_T(h_{T,neg})) / \tau\big)}$$

Here, $\mathrm{sim}$ is a similarity function (e.g., dot product or cosine similarity), $\tau$ is a temperature parameter controlling the sharpness of the distribution over negative samples, and the sum runs over negative teacher representations. $f_S$ and $f_T$ are again optional projection heads.
Contrastive objectives can be powerful for learning the underlying structure of the teacher's representation space without requiring strict element-wise alignment, potentially offering more flexibility to the student.
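As a sketch, here is an InfoNCE-style contrastive loss that uses the other examples in the batch as negatives (note that, as is common in practice, the denominator here also includes the positive pair; the dimensions and $\tau$ are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def contrastive_distillation_loss(student_repr, teacher_repr, tau=0.1):
    """Each student representation should be most similar to the teacher
    representation of the same input; other inputs in the batch act as negatives."""
    # L2-normalise so the dot product becomes cosine similarity.
    z_s = F.normalize(student_repr, dim=-1)   # (batch, dim)
    z_t = F.normalize(teacher_repr, dim=-1)   # (batch, dim)

    # Pairwise similarity of every student vector to every teacher vector.
    sim = z_s @ z_t.t() / tau                 # (batch, batch)

    # Diagonal entries correspond to the positive (same-input) pairs.
    labels = torch.arange(sim.size(0))
    return F.cross_entropy(sim, labels)

# Hypothetical usage: pooled representations after projection heads f_S and f_T.
student_repr = torch.randn(8, 256)
teacher_repr = torch.randn(8, 256)
loss = contrastive_distillation_loss(student_repr, teacher_repr)
```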
Often, the most effective distillation strategy involves combining multiple objectives. Different objectives capture complementary aspects of the teacher's knowledge. For example, one might combine soft target matching with intermediate feature matching and attention transfer:
$$L_{Total} = \lambda_{KD} L_{KD} + \lambda_{Inter} L_{Intermediate} + \lambda_{Attn} L_{Attention} + \dots$$

The hyperparameters $\lambda_i$ control the relative importance of each objective. Selecting the right combination and tuning these weights is a critical part of designing a successful distillation pipeline and typically involves significant experimentation based on the specific task, teacher-student architecture pairing, and performance metrics.
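In code, this amounts to a weighted sum of the individual terms; the weights below are placeholders meant to be tuned empirically:

```python
import torch

# Placeholder values standing in for the loss terms computed by the sketches above.
kd_loss = torch.tensor(1.2, requires_grad=True)
inter_loss = torch.tensor(0.8, requires_grad=True)
attn_loss = torch.tensor(0.3, requires_grad=True)

# Hypothetical weights; in practice tuned per task and architecture pairing.
lambda_kd, lambda_inter, lambda_attn = 1.0, 0.5, 0.1

total_loss = lambda_kd * kd_loss + lambda_inter * inter_loss + lambda_attn * attn_loss
total_loss.backward()
```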
Knowledge transfer points between teacher and student models, illustrating common distillation objectives: matching output logits (LKD), intermediate hidden states (LIntermediate), and attention maps (LAttention).
Choosing the appropriate objective, or combination of objectives, depends heavily on the specific characteristics of the teacher and student models, the available data, computational budget, and the desired trade-offs between student model size, speed, and fidelity.