知识蒸馏 (KD) 是一种构建高效深度学习模型的方法,与剪枝(通过移除网络部分来减小模型大小)或量化(降低数值精度)等技术不同。这种方法遵循“教师-学生学习”的理念。我们不直接压缩大型模型,而是训练一个更小、更高效的“学生”模型来模仿大型预训练“教师”模型的行为。背后的想法是,大型教师模型尽管复杂,但已习得丰富的表示和决策边界,能捕捉数据分布中不易察觉的信息。知识蒸馏旨在将这种“暗知识”传递给较小的学生模型。
教师-学生方法
在典型的知识蒸馏设置中,您会用到:
- 教师模型: 一个大型、高性能模型(例如 ResNet-101、模型集成或任何复杂的架构),它已经针对该任务进行了训练。该模型提供要传递的“知识”。
- 学生模型: 一个较小、计算成本较低的模型(例如 MobileNet、剪枝网络或教师模型的简化版本),我们希望对其进行训练以进行高效部署。
目标是训练学生模型,使其不仅能预测正确的标签(硬目标),还能匹配教师模型的输出分布(软目标)。
使用软目标传递知识
标准监督训练使用“硬目标”,这通常是表示真实类别的一热编码向量。例如,如果一张图片属于10个类别中的第3类,则硬目标是 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]。尽管有效,但此目标提供的信息有限;它只告诉模型哪个类别是正确的,而不是模型应如何在不正确的类别之间分配其概率质量。
然而,教师模型会生成更丰富的输出。其最后一层(在softmax激活之前)生成对数几率(logits),zt。对这些对数几率应用标准softmax函数,可以得到每个类别的概率分数 pt。这些概率通常包含有用的信息。例如,教师可能会给正确类别“狗”赋予高概率,但也会给“猫”或“狼”等相关类别赋予小的非零概率。这种分布反映了教师对类别相似性的认识。
知识蒸馏通过使用带有参数温度 T 的修改版softmax函数来实现此目的。标准softmax对应于 T=1。当 T>1 时,概率分布变得“更软”,意味着概率峰值更低,较小的对数几率会获得比 T=1 时更高的概率。这促使学生学习教师模型捕获的类别间关联。
类别 i 的软目标概率 qi 是通过教师模型的对数几率 zt,i 和温度 T 计算的:
qt,i=∑jexp(zt,j/T)exp(zt,i/T)
同样,学生模型生成自己的对数几率 zs,这些对数几率也通过相同的软化softmax函数,以生成软预测 qs:
qs,i=∑jexp(zs,j/T)exp(zs,i/T)
然后训练学生模型以匹配教师生成的这些软目标。
蒸馏损失函数
学生模型的训练目标通常结合两个损失部分:
- 标准交叉熵损失 (LCE): 这是在学生模型的标准预测(使用 T=1 的softmax)和硬真实标签之间计算的。这确保学生仍然能准确预测正确类别。设 ps 是学生模型的标准概率输出(T=1)。
L_{CE} = \text{CrossEntropy}(p_s, \text{hard_targets})
- 蒸馏损失 (LDistill): 此损失衡量学生模型的软预测 (qs) 和教师模型的软目标 (qt) 之间的差异。此损失的常用选择是 Kullback-Leibler (KL) 散度,它衡量两种概率分布之间的差异。有时,软目标之间的均方误差 (MSE) 也被使用。使用KL散度时:
LDistill=T2×KL(qs∣∣qt)
通常会包含 T2 缩放因子,以确保软目标产生的梯度幅值与温度变化时硬目标产生的梯度幅值大致相当。
最终的损失函数是这两个部分的加权和:
LTotal=αLCE+(1−α)LDistill
α 是一个超参数(通常在0到1之间),它平衡了匹配硬目标和匹配教师软目标的重要性。常见的做法是,开始时对蒸馏损失赋予较高的权重,并可能随着时间推移逐渐降低,或者简单地对 α 使用一个固定的较小值(例如0.1),从而在初始阶段赋予教师指导更大的权重。
基础知识蒸馏设置,显示教师生成软目标,学生通过蒸馏损失(比较软预测)和标准交叉熵损失(比较硬预测与真实标签)的组合进行训练。
其他蒸馏形式
虽然匹配最终输出分布是知识蒸馏最常见的形式,但这个想法可以扩展:
- 特征蒸馏(中间提示学习): 学生模型不仅匹配最终输出,还可以训练其模仿教师模型中间层生成的激活或特征图。这促使学生学习相似的内部表示。这通常涉及添加辅助损失项,以最小化特定层中教师和学生特征图之间的差异。
- 注意力蒸馏: 如果教师模型使用注意力机制,可以训练学生模型生成相似的注意力图,指导学生关注输入中同样重要的区域。
- 关系知识蒸馏: 这侧重于传递教师所认为的数据点之间的关联,而非单个点的直接输出。
实际考量与权衡
知识蒸馏是一种有效技术,但其成功取决于几个因素:
- 教师质量: 更好的教师通常能带来更好的学生,但教师无需完美。
- 学生容量: 学生模型必须有足够的容量来学习蒸馏所得的知识。过小的学生可能无法有效模仿教师。
- 温度 (T): 这是一个重要的超参数。更高的值会创建更软的分布,可能展现更多关于教师内部的知识,但也可能稀释信息。典型值范围为2到10,通常通过实验得出。
- 损失权重 (α): 平衡标准损失和蒸馏损失是重要的。最优值取决于任务和所涉及的模型。
- 训练数据: 知识蒸馏通常需要用于训练教师模型的原始训练数据集(或有代表性的子集)。
优势:
- 可以大幅提高小型模型的性能,通常超出其仅通过硬目标训练时的性能。
- 提供了一种将复杂模型或集成模型的知识压缩到单一、可部署模型中的方式。
- 学生模型在推理时的架构独立于教师模型;与传统训练的同等大小学生模型相比,部署时不会增加额外的计算成本。
劣势:
- 需要一个预训练的高性能教师模型,其获取成本可能较高。
- 学生的训练过程更复杂,涉及多个损失项和额外的超参数(T,α)。
- 寻找最优的教师-学生组合、温度和损失权重通常需要大量实验。
总之,知识蒸馏提供了一种有效的机制,用于将大型复杂模型中学到的信息传递给更小、更高效的模型。通过训练学生模型模仿教师模型的输出分布(软目标),同时通常也从真实标签(硬目标)中学习,我们可以创建紧凑的模型,这些模型保留了其大型对应模型的大部分性能优势,使其适合在资源受限的环境中部署。该技术补充了剪枝和量化等其他方法,构成了构建高效深度学习系统的工具集。