趋近智
基于本章前面介绍的原理,我们现在转向实际操作,进行知识蒸馏,将大型生成式语言模型(教师模型)的知识迁移到一个更小、更高效的模型(学生模型)中。目标是创建一个学生模型,使其在明显更小、更快的同时,仍能保留教师模型的大部分生成能力。本节提供一个分步指南,侧重于与生成模型相关的实施细节和评估方法。
在启动蒸馏过程前,需要进行细致的准备。
GPT-3.5、LLaMA-7B 这样的模型,甚至是针对特定风格或领域微调的版本。对于本实践指南,我们假设正在从 TeacherLM-7B(70亿参数)模型蒸馏知识。获取教师模型需要加载其权重和架构,这通常使用像 Hugging Face Transformers 这样的库。StudentLM-1B(10亿参数)。重要的一点是,学生模型的架构应与教师模型的输出格式兼容(例如,两者都在相同的词汇表上生成 logits)。虽然层数和隐藏维度会有所不同,但核心生成机制(例如,Transformer 解码器)应该相似。transformers、datasets,并可能包括 accelerate 以实现高效训练。# 配置
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
# 加载教师模型 (确保其处于评估模式且不需要梯度)
teacher_model_name = "path/to/large/teacher/model"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)
teacher_model.eval()
for param in teacher_model.parameters():
param.requires_grad = False
# 加载或定义学生模型
student_model_name = "path/to/smaller/student/config_or_model" # 或定义架构
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name) # 通常与教师模型相同
student_model = AutoModelForCausalLM.from_pretrained(student_model_name) # 或从配置初始化
# 加载数据集
# dataset = load_dataset(...)
知识蒸馏的根本在于指导学生模型训练的损失函数。
如前面介绍,一种常见方法是将标准语言建模损失(如果存在真实目标)与知识蒸馏损失相结合,以促使学生模型模仿教师模型的输出分布。
软标签 (KL 散度): 主要的知识蒸馏损失旨在最小化教师模型和学生模型在词汇表上的概率分布之间的 KL 散度。应用温度缩放来软化分布,防止模型对单个标记过度自信,并提供更丰富的监督信号。
单个标记预测的损失为:
LKD=T2⋅DKL(σ(zS/T)∣∣σ(zT/T))其中 zS 和 zT 分别是学生模型和教师模型生成的 logits,T 是温度(通常 T>1),σ 表示 softmax 函数。将此损失在序列长度和批次上平均,得到最终的知识蒸馏损失部分。
硬标签 (交叉熵): 如果蒸馏数据集包含真实下一个标记(例如,在继续预训练或微调期间),标准交叉熵损失 (LCE) 可以与知识蒸馏损失一同使用。这使学生模型立足于实际任务数据。
LCE=−i∑yilog(σ(zS)i)其中 yi 是第 i 个标记的独热编码真实标签。
组合损失: 最终损失函数通常是交叉熵损失(如果使用)和 KL 散度损失的加权和:
LTotal=(1−α)LCE+αLKD这里,α 是一个超参数(介于 0 和 1 之间),用于平衡真实标签和教师模型软标签的影响。选择合适的 α 和温度 T 通常需要进行实验。
为了更深层的知识迁移,特别是在架构存在差异时,匹配中间表示可以带来益处。
加入这些中间损失会增加复杂性,但可以明显提高学生模型对教师模型学到模式的理解。总损失变为 LCE、LKD 以及任何中间匹配损失的加权和。
训练循环需要修改以适应教师模型和自定义损失函数。
对于每个输入批次:
虽然自定义训练循环提供最大灵活性,但 Hugging Face Trainer 可以子类化以加入蒸馏功能。
# Trainer 的子类
from transformers import Trainer
import torch.nn.functional as F
import torch.nn as nn
class DistillationTrainer(Trainer):
def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
super().__init__(*args, **kwargs)
self.teacher_model = teacher_model
self.teacher_model.to(self.args.device) # 确保教师模型在同一设备上
self.temperature = temperature
self.alpha = alpha
# 如果需要匹配不同大小的隐藏状态,可能在此处添加投影层
def compute_loss(self, model, inputs, return_outputs=False):
# 学生模型前向传播 (Trainer 的标准行为)
student_outputs = model(**inputs)
student_logits = student_outputs.logits
# 如果提供了标签,计算标准交叉熵损失
if "labels" in inputs:
loss_ce = student_outputs.loss # Trainer 默认计算此项
else:
loss_ce = 0.0 # 或适当处理没有标签的情况
# 教师模型前向传播 (无梯度)
with torch.no_grad():
teacher_outputs = self.teacher_model(**inputs)
teacher_logits = teacher_outputs.logits
# 计算知识蒸馏损失 (确保因果语言模型的正确切片/对齐)
# 通常比较预测标记的 logits (移动 logits 和标签)
vocab_size = student_logits.size(-1)
student_log_probs = F.log_softmax(student_logits[:, :-1, :] / self.temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits[:, :-1, :] / self.temperature, dim=-1)
# KLDivLoss 期望 log-probs 作为输入,probs 作为目标
loss_kd = nn.KLDivLoss(reduction="batchmean")(student_log_probs, teacher_probs) * (self.temperature ** 2)
# 组合损失
# 如果使用标签:
# loss = (1.0 - self.alpha) * loss_ce + self.alpha * loss_kd
# 如果不使用标签 (纯粹从教师信号蒸馏):
loss = loss_kd # 如果需要,调整 alpha 逻辑
# 如果适用,在此处添加隐藏状态匹配损失
# loss_hidden = compute_hidden_state_loss(...)
# loss += beta * loss_hidden
return (loss, student_outputs) if return_outputs else loss
# 设置训练参数
# training_args = TrainingArguments(...)
# 实例化 DistillationTrainer
# trainer = DistillationTrainer(
# model=student_model,
# teacher_model=teacher_model,
# args=training_args,
# train_dataset=tokenized_dataset["train"],
# eval_dataset=tokenized_dataset["validation"],
# tokenizer=student_tokenizer,
# # data_collator=... # 对因果语言模型的填充和标签移动很重要
# temperature=2.0,
# alpha=0.5,
# )
# 开始训练
# trainer.train()
注意: 这段代码是。在损失计算中实现因果语言模型的标签移动和正确处理填充需要细致注意细节。
全面评估对于确认蒸馏过程的成功非常重要。
教师模型、蒸馏学生模型和从零开始训练的学生模型在下游任务(例如摘要)上的性能与模型大小的比较。蒸馏后的学生模型以明显更少的参数接近教师模型的性能。
评估生成文本的质量:
人工评估或与教师模型输出进行并排比较通常对于全面评估是必要的。
总是将蒸馏后的学生模型与相关基线进行比较:
本动手指南提供了生成式大型语言模型蒸馏的基本步骤。成功需要对架构、数据、损失函数和超参数进行细致的实验,并在定量和定性两个方面进行严格评估的指导。结果,如果成功,是一个明显更高效的模型,适合在资源受限的环境中部署。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造