趋近智
知识蒸馏 (knowledge distillation)提供了一种独特的模型压缩方法。它不是通过量化 (quantization)或剪枝直接修改大型模型的参数 (parameter),而是侧重于将知识从一个大型预训练 (pre-training)模型(“教师”模型)转移到一个更小、更高效的模型(“学生”模型)。目的是训练一个学生模型,使其通过从教师模型提供的更丰富的输出信号中学习,从而比从头开始使用相同架构训练的模型获得明显更好的性能。这使得学生模型更适合计算资源有限或对延迟有严格要求的部署环境。
其主要思想在于观察到,一个训练有素的大型模型能从数据中捕获复杂的模式和细微特征,这些不仅体现在其最终预测中,也体现在其内部表示和输出概率分布中。学生模型通过最小化一个损失函数 (loss function)来学习,该函数促使它模仿教师模型的这些行为,同时还使用真实标签来学习原始任务目标。
有几种策略可用于将知识从教师模型迁移到学生模型:
匹配输出Logits(软目标): 这是知识蒸馏 (knowledge distillation)最常见的形式。学生模型不仅单独使用硬真实标签(例如,one-hot编码向量 (vector))进行训练,还被训练去匹配教师模型在可能输出类别或token上产生的概率分布。为了提供更丰富的学习信号,教师模型和学生模型的输出通常使用softmax函数中的温度参数 (parameter)()进行“软化”:
这里, 表示类别 的logits。较高的温度()会生成一个更柔和的类别概率分布,从而显示更多关于教师模型内部“置信度”以及类别之间相似结构的信息。温度 对应于标准softmax。蒸馏损失通常是教师模型()和学生模型()软化概率分布之间的Kullback-Leibler(KL)散度或均方误差(MSE):
或
其中, 和 分别是教师模型和学生模型的 logits(softmax之前的输出)。直接使用logits与MSE有时会更简单,并且同样有效。
匹配中间特征: 知识也可以通过鼓励学生模型复制教师模型中间层的激活或隐藏状态来迁移。这迫使学生学习相似的内部表示。损失函数 (loss function)(通常是MSE)计算教师模型和学生模型相应层特征图之间的差异。
这里, 和 表示输入 在教师模型和学生模型选定层中的特征激活。这里的挑战是层的对齐 (alignment),特别是如果架构差异很大。通常,会学习线性变换,以便在计算损失之前将学生模型的特征映射到教师模型特征的维度。
匹配注意力机制 (attention mechanism): 对于基于Transformer的模型,教师模型学习到的注意力模式包含token之间有价值的关系信息。蒸馏可以包括训练学生模型生成与教师模型相似的注意力图。损失是根据相应层或头的注意力权重 (weight)矩阵之间的差异计算的。
知识蒸馏 (knowledge distillation)的标准训练设置包括:
学生模型使用一个组合损失函数 (loss function)进行训练,该函数是标准任务损失(,例如,针对真实标签的交叉熵损失)和蒸馏损失()的加权和:
超参数 (parameter) (hyperparameter) (通常在0到1之间)平衡了匹配真实数据与模仿教师模型的重要性。用于软化logits的温度 是另一个重要的超参数。
这是一个PyTorch代码片段,展示了使用KL散度进行logits匹配的损失计算逻辑:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 假设 teacher_model 和 student_model 已定义
# 假设输入和标签可从数据加载器获取
teacher_model.eval() # 教师模型处于评估模式且已冻结
student_model.train() # 学生模型处于训练模式
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
# 超参数
temperature = 4.0
alpha = 0.3 # 标准任务损失的权重
# 标准交叉熵损失
criterion_task = nn.CrossEntropyLoss()
# 蒸馏的KL散度损失
criterion_kd = nn.KLDivLoss(reduction='batchmean') # 使用 batchmean 归约
# --- 训练循环 ---
# for inputs, labels in dataloader:
optimizer.zero_grad()
# 获取学生模型的输出
student_logits = student_model(inputs)
# 获取教师模型的输出(不需要梯度)
with torch.no_grad():
teacher_logits = teacher_model(inputs)
# 计算标准任务损失(使用学生logits和真实标签)
# 注意:CrossEntropyLoss 期望原始 logits
loss_task = criterion_task(student_logits, labels)
# 计算蒸馏损失(使用软化 logits)
# 使用温度应用 Softmax,然后是 LogSoftmax,以保证 KLDivLoss 输入的稳定性
student_log_probs_soft = F.log_softmax(student_logits / temperature, dim=-1)
teacher_probs_soft = F.softmax(teacher_logits / temperature, dim=-1)
# KLDivLoss 期望学生模型的对数概率
# 以及教师模型的概率
# 根据 Hinton 原始蒸馏论文的缩放,乘以 T*T
loss_kd = criterion_kd(student_log_probs_soft,
teacher_probs_soft) * (temperature ** 2)
# 组合损失
total_loss = alpha * loss_task + (1 - alpha) * loss_kd
total_loss.backward()
optimizer.step()
# --- 训练循环结束 ---
print(
f"任务损失: {loss_task.item():.4f}, "
f"KD 损失: {loss_kd.item():.4f}, "
f"总损失: {total_loss.item():.4f}"
)
此代码片段展示了将标准交叉熵损失与基于KL散度的蒸馏损失结合(使用软化输出)的主要逻辑。
temperature和alpha参数控制蒸馏过程。
学生模型的架构通常选择比教师模型小很多、速度快很多。这可能包括:
学生模型不一定是教师模型的严格子集。可以尝试不同的架构选择,只要学生模型具有足够的容量,能够有效地学习目标任务的蒸馏知识即可。
优点:
缺点与考量:
知识蒸馏 (knowledge distillation)提供了一种有效的方法,用于创建更小、更高效的语言模型,这些模型保留了其大型对应模型的大部分预测能力,使其成为与量化 (quantization)和剪枝并列的LLM压缩工具箱中有价值的方法。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•