尽管近端策略优化(PPO)为我们基于训练好的奖励模型(RM)的奖励信号更新语言模型提供了一种方法,但简单地应用它可能导致不好的结果。RL策略($\pi_{RL}$)可能会找到一些序列,这些序列从RM那里获得高分,但却不自然、重复、无意义,或风格与原始监督微调(SFT)模型($\pi_{SFT}$)差异很大。这种现象有时被称为“奖励欺骗”或“模式崩塌”,即模型过度优化代理奖励信号,从而失去在预训练和SFT期间学到的通用语言能力。为此,标准的RLHF流程引入了一个惩罚项到PPO目标函数中,该惩罚项基于RL策略的输出分布与SFT模型输出分布之间的Kullback-Leibler(KL)散度。让策略保持稳定KL散度,表示为 $KL(P || Q)$,衡量一个概率分布 $P$ 与参考概率分布 $Q$ 的差异程度。KL散度为零表示分布完全相同。在RLHF的背景下,我们希望衡量对于给定输入提示 $x$ 和可能的输出标记 $y$,当前RL策略 $\pi_{RL}(y|x)$ 与参考SFT策略 $\pi_{SFT}(y|x)$ 偏离了多少。目标不只是最大化来自RM的奖励,而是要在不 偏离太远 于表现良好的SFT模型行为的前提下进行。我们将这一约束直接纳入PPO使用的奖励信号中。给定提示 $x$,生成序列 $y = (y_1, ..., y_T)$ 的修改后奖励变为:$$ \text{总奖励}(x, y) = R_{RM}(x, y) - \beta \cdot KL(\pi_{RL}(y|x) || \pi_{SFT}(y|x)) $$这里:$R_{RM}(x, y)$ 是奖励模型分配给生成的序列 $(x, y)$ 的最终奖励。$KL(\pi_{RL}(y|x) || \pi_{SFT}(y|x))$ 是RL策略和SFT策略分配的序列概率分布之间的KL散度。在实际操作中,这通常在生成过程中按每个标记进行近似或计算。$\beta$ 是一个超参数,用于控制KL惩罚的强度。更高的 $\beta$ 值会对偏离SFT模型的行为施加更强的惩罚,使 $\pi_{RL}$ 保持接近 $\pi_{SFT}$。更低的 $\beta$ 值允许RL策略有更大的自由度来生成最大化RM分数的输出,即使它们与SFT模型的输出有明显差异。实践中计算KL惩罚计算整个序列的精确KL散度通常是不可行的。相反,PPO通常在标记级别上操作。在PPO的生成(rollout)阶段,对于在给定上下文 $x, y_{<t}$ 情况下,在步骤 $t$ 生成的每个标记 $y_t$,我们计算:该标记根据当前RL策略的对数概率:$\log \pi_{RL}(y_t | x, y_{<t})$同一标记根据参考SFT策略的对数概率:$\log \pi_{SFT}(y_t | x, y_{<t})$按标记的KL惩罚项通常通过这些对数概率的差值来近似:$$ \text{按标记的KL惩罚近似} \approx \log \pi_{RL}(y_t | x, y_{<t}) - \log \pi_{SFT}(y_t | x, y_{<t}) $$该项在计算优势和更新策略之前,从RM提供的按标记奖励中减去(乘以 $\beta$ 后)。参考SFT模型的参数在此RL阶段保持冻结;它只执行前向传播来提供参考对数概率。让我们看一个PyTorch代码片段,说明如何为一批生成的序列计算此惩罚。假设我们有来自两个模型在每个位置上词汇表中每个标记的对数概率。import torch import torch.nn.functional as F # 假设 log_probs_rl 和 log_probs_sft 是模型的对数softmax输出 # 来自模型 # 形状: [批大小, 序列长度, 词汇表大小] # actions: RL策略在rollout期间实际生成的标记索引 # # 形状: [批大小, 序列长度] def calculate_kl_penalty(log_probs_rl, log_probs_sft, actions, beta): """ 计算RLHF PPO中使用的KL惩罚项。 参数: log_probs_rl: 来自RL策略模型的对数概率。 log_probs_sft: 来自参考SFT模型的对数概率 (已分离)。 actions: RL策略生成的标记ID。 beta: KL惩罚系数。 返回: kl_penalty: 每个标记位置的KL惩罚张量。 形状: [批大小, 序列长度] """ # 确保 log_probs_sft 不参与梯度计算 log_probs_sft_detached = log_probs_sft.detach() # 收集所采取特定动作的对数概率 log_prob_rl_taken = torch.gather( log_probs_rl, 2, actions.unsqueeze(2) ).squeeze(2) log_prob_sft_taken = torch.gather( log_probs_sft_detached, 2, actions.unsqueeze(2) ).squeeze(2) # 计算近似的按标记KL散度(对数比) log_ratio = log_prob_rl_taken - log_prob_sft_taken # 按 beta 进行缩放 kl_penalty = beta * log_ratio return kl_penalty # --- 示例用法 --- # batch_size, seq_len, vocab_size = 4, 50, 32000 # log_probs_rl = torch.randn( # batch_size, seq_len, vocab_size # ).log_softmax(dim=-1) # log_probs_sft = torch.randn( # batch_size, seq_len, vocab_size # ).log_softmax(dim=-1) # actions = torch.randint(0, vocab_size, (batch_size, seq_len)) # beta_kl = 0.1 # penalty = calculate_kl_penalty( # log_probs_rl, log_probs_sft, actions, beta_kl # ) # print(f"KL Penalty Shape: {penalty.shape}") # # 输出: KL Penalty Shape: torch.Size([4, 50]) # 这个“惩罚”通常会从RM提供的奖励中减去 # 在PPO算法中计算优势之前。 # reward_signal = reward_from_rm - penalty此代码片段说明了将 所采取动作 的对数概率比用作KL惩罚项的常见做法。它比在每一步计算整个词汇表上的完整KL散度计算成本更低,但其作用相同,都是为了惩罚偏离行为。权衡考量KL惩罚作为一个重要的正则化项。它防止RL策略陷入过度依赖奖励模型而牺牲语言质量的狭窄模式。通过将RL策略与SFT策略关联起来,这有助于确保模型保留其流畅性、连贯性和其中包含的广泛知识。选择合适的 $\beta$ 值很重要。如果 $\beta$ 过高,RL训练可能会受到过度限制,阻止策略基于奖励信号明显改进。模型将非常紧密地遵循原始SFT行为。如果 $\beta$ 过低,策略可能会偏离过多,可能导致奖励欺骗和输出质量下降,即使RM分数很高。找到一个合适的 $\beta$ 通常需要凭经验调整,并仔细评估最大化RM分数与保持理想语言特性之间的权衡。digraph RLHF_Objective { rankdir=TD; node [shape=box, style=rounded, fontname="Arial", color="#495057", fillcolor="#e9ecef", style="filled,rounded", fontsize=12]; edge [color="#adb5bd", fontsize=12]; Objective [label="PPO每标记目标", shape=hexagon, fillcolor="#748ffc", fontcolor="white"]; Maximize [label="最大化", shape=plaintext]; Reward [label="来自RM的奖励\n(越高越好)", fillcolor="#b2f2bb"]; Penalty [label="KL惩罚项\n(越低越好)", fillcolor="#ffc9c9"]; LogRatio [label="log \u03c0_RL(y|x) - log \u03c0_SFT(y|x)", fillcolor="#ffec99"]; Beta [label="\u03b2 (权重)", shape=ellipse, fillcolor="#ced4da"]; Objective -> Maximize [arrowhead=none]; Maximize -> Reward; Maximize -> Penalty [label="-"]; Penalty -> LogRatio; Penalty -> Beta [label="*"]; } PPO目标旨在最大化来自奖励模型的奖励,同时最小化通过 \u03b2 缩放的KL惩罚,这会阻止偏离原始SFT模型的行为。总之,KL散度惩罚是RLHF中的一种标准方法,用于稳定训练并防止语言模型过度偏离其初始SFT状态。它鼓励模型根据人类反馈(由RM捕获)找到受青睐的输出,同时保持其在之前训练阶段学到的流畅性和通用能力。