趋近智
尽管近端策略优化(PPO)为我们基于训练好的奖励模型(RM)的奖励信号更新语言模型提供了一种方法,但简单地应用它可能导致不好的结果。RL策略()可能会找到一些序列,这些序列从RM那里获得高分,但却不自然、重复、无意义,或风格与原始监督微调 (fine-tuning)(SFT)模型()差异很大。这种现象有时被称为“奖励欺骗”或“模式崩塌”,即模型过度优化代理奖励信号,从而失去在预训练 (pre-training)和SFT期间学到的通用语言能力。
为此,标准的RLHF流程引入了一个惩罚项到PPO目标函数中,该惩罚项基于RL策略的输出分布与SFT模型输出分布之间的Kullback-Leibler(KL)散度。
KL散度,表示为 ,衡量一个概率分布 与参考概率分布 的差异程度。KL散度为零表示分布完全相同。在RLHF的背景下,我们希望衡量对于给定输入提示 和可能的输出标记 (token) ,当前RL策略 与参考SFT策略 偏离了多少。
目标不只是最大化来自RM的奖励,而是要在不 偏离太远 于表现良好的SFT模型行为的前提下进行。我们将这一约束直接纳入PPO使用的奖励信号中。给定提示 ,生成序列 的修改后奖励变为:
这里:
计算整个序列的精确KL散度通常是不可行的。相反,PPO通常在标记 (token)级别上操作。在PPO的生成(rollout)阶段,对于在给定上下文 (context) 情况下,在步骤 生成的每个标记 ,我们计算:
按标记的KL惩罚项通常通过这些对数概率的差值来近似:
该项在计算优势和更新策略之前,从RM提供的按标记奖励中减去(乘以 后)。参考SFT模型的参数 (parameter)在此RL阶段保持冻结;它只执行前向传播来提供参考对数概率。
让我们看一个PyTorch代码片段,说明如何为一批生成的序列计算此惩罚。假设我们有来自两个模型在每个位置上词汇表 (vocabulary)中每个标记的对数概率。
import torch
import torch.nn.functional as F
# 假设 log_probs_rl 和 log_probs_sft 是模型的对数softmax输出
# 来自模型
# 形状: [批大小, 序列长度, 词汇表大小]
# actions: RL策略在rollout期间实际生成的标记索引
#
# 形状: [批大小, 序列长度]
def calculate_kl_penalty(log_probs_rl, log_probs_sft, actions, beta):
"""
计算RLHF PPO中使用的KL惩罚项。
参数:
log_probs_rl: 来自RL策略模型的对数概率。
log_probs_sft: 来自参考SFT模型的对数概率
(已分离)。
actions: RL策略生成的标记ID。
beta: KL惩罚系数。
返回:
kl_penalty: 每个标记位置的KL惩罚张量。
形状: [批大小, 序列长度]
"""
# 确保 log_probs_sft 不参与梯度计算
log_probs_sft_detached = log_probs_sft.detach()
# 收集所采取特定动作的对数概率
log_prob_rl_taken = torch.gather(
log_probs_rl, 2, actions.unsqueeze(2)
).squeeze(2)
log_prob_sft_taken = torch.gather(
log_probs_sft_detached, 2, actions.unsqueeze(2)
).squeeze(2)
# 计算近似的按标记KL散度(对数比)
log_ratio = log_prob_rl_taken - log_prob_sft_taken
# 按 beta 进行缩放
kl_penalty = beta * log_ratio
return kl_penalty
# --- 示例用法 ---
# batch_size, seq_len, vocab_size = 4, 50, 32000
# log_probs_rl = torch.randn(
# batch_size, seq_len, vocab_size
# ).log_softmax(dim=-1)
# log_probs_sft = torch.randn(
# batch_size, seq_len, vocab_size
# ).log_softmax(dim=-1)
# actions = torch.randint(0, vocab_size, (batch_size, seq_len))
# beta_kl = 0.1
# penalty = calculate_kl_penalty(
# log_probs_rl, log_probs_sft, actions, beta_kl
# )
# print(f"KL Penalty Shape: {penalty.shape}")
# # 输出: KL Penalty Shape: torch.Size([4, 50])
# 这个“惩罚”通常会从RM提供的奖励中减去
# 在PPO算法中计算优势之前。
# reward_signal = reward_from_rm - penalty
此代码片段说明了将 所采取动作 的对数概率比用作KL惩罚项的常见做法。它比在每一步计算整个词汇表上的完整KL散度计算成本更低,但其作用相同,都是为了惩罚偏离行为。
KL惩罚作为一个重要的正则化 (regularization)项。它防止RL策略陷入过度依赖奖励模型而牺牲语言质量的狭窄模式。通过将RL策略与SFT策略关联起来,这有助于确保模型保留其流畅性、连贯性和其中包含的广泛知识。
选择合适的 值很重要。
找到一个合适的 通常需要凭经验调整,并仔细评估最大化RM分数与保持理想语言特性之间的权衡。
PPO目标旨在最大化来自奖励模型的奖励,同时最小化通过 \u03b2 缩放的KL惩罚,这会阻止偏离原始SFT模型的行为。
总之,KL散度惩罚是RLHF中的一种标准方法,用于稳定训练并防止语言模型过度偏离其初始SFT状态。它鼓励模型根据人类反馈(由RM捕获)找到受青睐的输出,同时保持其在之前训练阶段学到的流畅性和通用能力。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•