知识蒸馏 (knowledge distillation)(KD)的主要问题在于准确界定要从教师模型传递给学生模型的是何种知识,以及如何衡量这种传递的成效。实现此目的的机制是蒸馏目标或损失函数 (loss function)。它量化 (quantization)了教师模型行为与学生模型行为之间的差异,以此引导学生模型的训练过程。尽管最初的KD方法主要侧重于匹配输出分布,但目前已发展出多种复杂的训练目标,以获取教师模型中包含的更丰富的信息。
软目标:匹配输出分布
最基本的KD目标由Hinton等人(2015)提出,它包含训练学生模型去模仿教师模型在类别或标记 (token)上的输出概率分布。简单匹配最终预测(硬标签)是不够的,因为教师模型的分布通常包含关于类别关系或标记可能性的细微信息——这些信息在转换为单一硬预测时就会丢失。
为提取这些更丰富的信息,在应用于logits(即最终激活前的原始、未归一化 (normalization)输出)的softmax函数中引入了一个温度缩放参数 (parameter)T。
σ(zi,T)=∑jexp(zj/T)exp(zi/T)
这里,zi 是类别或标记i的logit。较高的温度(T>1)会使概率分布更平滑,使得概率值彼此靠近,从而显示出教师模型分配给不同输出的相对相似性。当温度T=1时,则恢复为标准softmax。
蒸馏损失LKD通常是学生模型的平滑预测(pS=σ(zS,T))与教师模型的平滑预测(pT=σ(zT,T))之间的Kullback-Leibler(KL)散度:
LKD=T2⋅DKL(pS∣∣pT)=T2i∑pT(i)logpS(i)pT(i)
T2缩放因子很重要。因为平滑目标产生的梯度相对于硬目标产生的梯度按1/T2缩放,所以将KD损失乘以T2可以确保即使温度T发生变化,KD损失在训练期间的相对贡献也大致保持不变。
这个目标鼓励学生模型不仅要预测正确的输出,还要理解教师模型为何这样预测,从而习得输出之间的关系。在实践中,这个LKD通常与针对真实硬标签的标准监督损失(例如交叉熵LCE)相结合,使用加权因子α:
LTotal=αLCE(ytrue,σ(zS,T=1))+(1−α)LKD(σ(zS,T),σ(zT,T))
这确保了学生模型在从教师模型的软目标中受益的同时,仍能学习匹配真实标签。选择最佳温度T和权重 (weight)α通常需要经验调整。
中间表示匹配
尽管匹配输出分布是有效的,但知识并非仅包含在最终层中。像LLM这样的深度网络的中间层会学习分层表示,这些表示捕获了语法、语义和上下文 (context)信息。蒸馏这种中间知识可以为学生模型提供更强的指导。
这里的目标是最小化教师模型(hTl)和学生模型(hSl)中选定中间层的隐藏状态或激活之间的差异。为此常用的损失函数 (loss function)包含:
- 均方误差(MSE): 直接最小化欧几里得距离。
LIntermediateMSE=l∈Lmatch∑∣∣fS(hSl)−fT(hTl)∣∣22
- 余弦相似度损失: 最大化余弦相似度,侧重于表示向量 (vector)之间的角度而非其大小。在大小可能显著不同时很有用。
LIntermediateCosine=l∈Lmatch∑(1−cos(fS(hSl),fT(hTl)))
这里,Lmatch是用于匹配的层索引集。函数fS和fT表示可选的转换层(例如线性投影),用于在学生模型和教师模型的层具有不同隐藏大小时对齐 (alignment)维度。
中间匹配的考量包含:
- 层选择: 哪些层包含最有价值的可传递信息?早期层可能捕获基本特征,而更深的层则捕获更抽象的内容。匹配多个层是常见的做法。
- 架构不匹配: 如果学生模型的层数少于教师模型,则需要策略来将教师模型层映射到学生模型层(例如,将学生模型的最后一层与教师模型的最后一层匹配,或使用学习到的投影)。
- 计算成本: 计算这些损失会增加训练过程的开销,尤其是在匹配大型激活张量时。
注意力转移
对于基于Transformer的LLM,自注意力 (self-attention)机制 (attention mechanism)是一个界定性组件。注意力图表示了不同位置标记 (token)之间的加权关系,其编码了重要的结构和上下文 (context)信息。转移这种注意力知识可以帮助学生模型学习相似的关系模式。
注意力转移(AT)目标最小化相应层中的注意力图(ATl,ASl)之间的差异:
LAttention=l∈Lmatch∑Nh1h=1∑Nh∣∣AS,hl−AT,hl∣∣F2
其中Nh是注意力头的数量,∣∣⋅∣∣F2表示层l中头h的注意力矩阵之差的平方Frobenius范数(平方元素之和)。
挑战包含:
- 头部映射: 如果学生模型和教师模型的注意力头数量不同,则需要映射策略(例如,平均教师模型的头部,使用学习到的投影)。
- 计算开销: 存储和比较注意力图会增加内存和计算需求。
- 解释性: 尽管注意力模式提供了信息,但与匹配隐藏状态相比,直接强制学生模型匹配这些模式可能过于严格。
对比学习目标
对比目标侧重于学习相似性和非相似性,而非直接逐元素(如MSE)或逐角度(如余弦)匹配表示。对比表示蒸馏(CRD)旨在训练学生模型,使其为相同输入(正例对)生成的表示接近教师模型的表示,但远离不同输入(负例对)的教师模型表示。
总的来说,该损失鼓励sim(zS,zT)对于相同输入较高,而sim(zS,zT,neg)应较低,其中zT,neg是批次中或记忆库中其他输入的教师模型表示。典型的损失函数 (loss function)如InfoNCE可以被调整:
LContrastive∝−log∑hT,negexp(sim(fS(hS),fT(hT,neg))/τ)exp(sim(fS(hS),fT(hT))/τ)
这里,sim是一个相似性函数(例如点积或余弦相似度),τ是一个温度参数 (parameter),用于控制负样本分布的锐度,求和是针对负教师模型表示进行的。fS和fT再次是可能的投影头部。
对比目标在学习教师模型表示空间的潜在结构方面可以很强大,无需严格的逐元素对齐 (alignment),这可能为学生模型提供更大的灵活性。
组合蒸馏目标
通常,最有效的蒸馏策略包含组合多个目标。不同的目标捕获了教师模型知识的互补方面。例如,可以将软目标匹配与中间特征匹配和注意力转移相结合:
LTotal=λKDLKD+λInterLIntermediate+λAttnLAttention+…
超参数 (parameter) (hyperparameter)λi控制每个目标的相对重要性。选择正确的组合并调整这些权重 (weight)是设计成功蒸馏流程的核心部分,通常需要根据具体任务、教师-学生模型架构配对和性能指标进行大量实验。
教师模型和学生模型之间的知识传递点,显示了常见的蒸馏目标:匹配输出logits(LKD)、中间隐藏状态(LIntermediate)和注意力图(LAttention)。
选择合适的单一目标或目标组合,很大程度上取决于教师模型和学生模型的具体特点、可用数据、计算预算以及学生模型大小、速度和保真度之间所需的权衡。