趋近智
大师班
预训练语言模型尽管其表现出色的能力,但通常缺乏可靠地遵循用户指令或遵守期望行为准则所需的特定调整。它们通过海量文本数据训练,以预测序列中的下一个词元,但这个目标不能直接转化为人类期望中定义的帮助性、诚实性或无害性。监督微调(SFT)是一种旨在弥补这一差距的技术,它通过明确地教导模型如何以更受偏好的方式回应提示。
SFT 借助由精选输入提示及其对应期望输出组成的数据集,来调整预训练的大语言模型。可以将其视为直接为模型提供它应该如何表现的示例。模型不再从网络规模文本的隐含模式中学习,而是从优秀回复的明确示范中学习。这个过程包括在这些监督示例上进一步训练预训练模型,通常采用交叉熵等标准序列到序列损失函数。
其核心是,SFT 通过使模型生成输出与微调数据集中提供的目标输出之间的差异最小化,从而优化模型的参数。这个过程通常遵循以下步骤:
<|prompt|> 马来西亚的首都在哪里?<|response|> 马来西亚的首都是吉隆坡。<|endoftext|>。desired_response部分的词元计算。提示词元作为上下文,但不直接参与损失计算或梯度更新。这种有针对性的损失计算非常重要。我们希望模型学习如何根据提示生成回复,而不是简单地预测提示词元本身(这部分它在预训练期间已经学习过)。
考虑目标函数。在预训练期间,模型最大化整个文本语料库的似然,即 P(ext文本)。在SFT中,模型学习一个条件概率:给定特定提示,它最大化期望回复的似然,即 P(ext回复∣ext提示)。这种转变使模型专注于根据指令输入生成适当的输出。
我们可以用图示呈现单个SFT训练步骤中信息的基本流向:
SFT 过程的简化示意图,显示了如何使用数据集中的提示和期望回复来计算损失并更新预训练大语言模型的权重。
为了在实际操作中,使用 PyTorch 等框架实现有针对性的损失计算,我们通常会创建一个损失掩码。这个掩码确保只有与期望回复对应的词元才参与损失计算。
这是一个 PyTorch 代码片段,说明了这一点:
import torch
import torch.nn.functional as F
# 假设:
# - logits: 模型输出的对数几率 [batch_size, sequence_length, vocab_size]
# - labels: 目标词元ID [batch_size, sequence_length]
# - prompt_lengths: 批次中每个项目的提示部分长度 [batch_size]
# - IGNORE_INDEX: 损失函数会忽略的特殊索引(例如 -100)
# 假设 IGNORE_INDEX 全局定义,例如:
IGNORE_INDEX = -100
def calculate_sft_loss(logits, labels, prompt_lengths):
"""仅针对回复词元计算交叉熵损失。"""
batch_size, sequence_length, vocab_size = logits.shape
# 为下一个词元预测平移对数几率和标签
# 预测词元 i 的对数几率位于索引 i-1
# 词元 i 的标签位于索引 i
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 创建一个损失掩码
# 初始化掩码为 1(计算损失)
loss_mask = torch.ones_like(shift_labels, dtype=torch.bool)
# 将提示词元的掩码设置为 0(忽略损失)
for i in range(batch_size):
# 提示长度包含初始词元,
# 因此在平移序列中,掩码范围直到 prompt_length - 1
prompt_end_index = prompt_lengths[i] - 1
if prompt_end_index > 0: # 确保存在要掩码的提示部分
loss_mask[i, :prompt_end_index] = 0
# 应用掩码:掩码为 0 的地方,将标签设置为 IGNORE_INDEX
masked_labels = shift_labels.masked_fill(~loss_mask, IGNORE_INDEX)
# 展平序列维度以计算损失
shift_logits = shift_logits.view(-1, vocab_size)
masked_labels = masked_labels.view(-1)
# 计算交叉熵损失,忽略 IGNORE_INDEX
loss = F.cross_entropy(shift_logits,
masked_labels,
ignore_index=IGNORE_INDEX)
return loss
# --- 示例用法 ---
# batch_size = 2
# sequence_length = 10
# vocab_size = 1000
# prompt_lengths = torch.tensor([3, 5]) # 每个项目的提示长度
# dummy_logits = torch.randn(batch_size,
# sequence_length,
# vocab_size,
# requires_grad=True)
# dummy_labels = torch.randint(0, vocab_size, (batch_size, sequence_length))
# sft_loss = calculate_sft_loss(dummy_logits, dummy_labels, prompt_lengths)
# print(f"Calculated SFT Loss: {sft_loss.item()}")
# sft_loss.backward() # 计算梯度
这个代码片段概述了如何掩盖损失计算,确保在SFT期间只有回复词元影响模型更新。
监督微调有几个重要的对齐目标:
虽然 SFT 在教导模型何种回复是基于示例所期望的方面是有效的,但它本身不能完美地捕获人类偏好。它教导模型模仿所提供回复的风格和内容。对于更复杂的对齐目标,例如判断多个合理回复之间的相对质量,或优化“帮助性”等特质,SFT 之后通常会采用来自人类反馈的强化学习(RLHF)等技术,我们将在下一章讨论。SFT 提供了一个基础,使模型具备遵循指令的基本能力,之后再使用基于偏好的方法进行进一步的优化。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造