趋近智
大师班
尽管使用近端策略优化 (PPO) 的人类反馈强化学习 (RLHF) 在使语言模型与人类偏好保持一致方面已被证明有效,但此过程可能复杂且有时不稳定。它通常包含多个阶段:监督微调 (SFT)、基于人类偏好数据训练一个独立的奖励模型 (RM),然后使用由 RM 指导的强化学习对 SFT 模型进行微调。这种多阶段流程引入了多个超参数和潜在的失败点,特别是在强化学习阶段,该阶段对实现细节敏感,且容易出现奖励作弊等问题。
直接偏好优化 (DPO) 提供了一种更简化的偏好对齐方法,绕过了显式奖励模型训练的需求,并彻底避免了强化学习的复杂性。DPO 使用一个简单的分类目标,直接基于偏好数据优化语言模型。
DPO 的核心思路源于 RLHF 寻求的最优策略与潜在(隐式)奖励函数之间的数学关系。回想一下,标准奖励建模通常使用 Bradley-Terry 模型,将提示 x 的成对偏好 (yw,yl)(其中 yw 优于 yl)与潜在奖励函数 r 关联起来:
P(yw≻yl∣x)=σ(r(x,yw)−r(x,yl))
此处,σ 是 logistic 函数。RLHF 的目标是找到一个策略 πRL,在使预期奖励 E[r(x,y)] 最大化的同时,通过 KL 散度惩罚,保持与参考策略 πref(通常是 SFT 模型)的接近度:
maxπRLE(x,y)∼πRL[r(x,y)]−βDKL(πRL(y∣x)∣∣πref(y∣x))
DPO 使用此约束优化问题的解析解。可以证明,最优策略 πRL 具有以下形式:
πRL(y∣x)=Z(x)1πref(y∣x)exp(β1r(x,y))
其中 Z(x) 是一个配分函数,确保概率和为一。此方程连接了最优策略、参考策略和奖励函数。通过将此关系代回 Bradley-Terry 偏好模型,可以消除奖励函数 r(x,y),直接用策略表示偏好概率:
P(yw≻yl∣x)=σ(βlogπref(yw∣x)πRL(yw∣x)−βlogπref(yl∣x)πRL(yl∣x))
DPO 直接训练语言模型 πθ 来满足此偏好模型,其中 πθ 代替了未知的最优策略 πRL。其目标是最大化在此模型下观测到的人类偏好的对数似然。这从而得到 DPO 损失函数:
LDPO(πθ;πref)=−E(x,yw,yl)∼D[logσ(βlogπref(yw∣x)πθ(yw∣x)−βlogπref(yl∣x)πθ(yl∣x))]
此处,D 是偏好三元组 (x,yw,yl) 的数据集,πθ 是正在优化的语言模型,πref 是冻结的参考 SFT 模型,β 是一个超参数,它隐式控制优化策略 πθ 与参考策略 πref 的偏离程度。更高的 β 会赋予偏好数据更大的权重,可能导致更大的偏离。
实现 DPO 比 RLHF 显著更简单。它需要:
训练循环涉及标准监督学习:
下面是一个 PyTorch 代码片段,用于在训练步骤中计算 DPO 损失的核心部分,假设您已获得对数概率:
import torch
import torch.nn.functional as F
# 假设 log_probs_policy 和 log_probs_ref 包含对数概率
# 形状: (batch_size,),适用于选中 (w) 和拒绝 (l) 响应
# log_probs_policy_w, log_probs_policy_l: 来自正在
# 训练的模型 (pi_theta) 的对数概率
# log_probs_ref_w, log_probs_ref_l: 来自冻结的参考
# 模型 (pi_ref) 的对数概率
# beta: 超参数(例如 0.1)
def dpo_loss(log_probs_policy_w, log_probs_policy_l,
log_probs_ref_w, log_probs_ref_l, beta):
"""计算一批偏好的 DPO 损失。"""
# 计算对数比率
log_ratio_w = log_probs_policy_w - log_probs_ref_w
log_ratio_l = log_probs_policy_l - log_probs_ref_l
# 对数比率差值乘以 beta
diff_scaled = beta * (log_ratio_w - log_ratio_l)
# 使用 logistic sigmoid 计算损失
loss = -F.logsigmoid(diff_scaled)
# 对批次内的损失取平均
return loss.mean()
# 在训练循环中的示例用法
# batch = get_preference_batch()
# # (提示, 选中响应, 拒绝响应)
# outputs_policy = policy_model(prompts,
# chosen_responses,
# rejected_responses)
# with torch.no_grad():
# outputs_ref = ref_model(prompts,
# chosen_responses,
# rejected_responses)
#
# # 提取对数概率(细节取决于模型实现)
# log_probs_policy_w = get_log_probs(outputs_policy, chosen_responses)
# log_probs_policy_l = get_log_probs(outputs_policy, rejected_responses)
# log_probs_ref_w = get_log_probs(outputs_ref, chosen_responses)
# log_probs_ref_l = get_log_probs(outputs_ref, rejected_responses)
#
# loss = dpo_loss(log_probs_policy_w, log_probs_policy_l,
# log_probs_ref_w, log_probs_ref_l, beta=0.1)
# loss.backward()
# optimizer.step()
优点:
缺点:
总之,DPO 为标准 RLHF 流程提供了一种更简单的替代方案。其稳定性及易于实现使其成为一个有吸引力的选择,尤其是在计算资源或强化学习专业知识有限的情况下,它能有效地使语言模型与人类偏好保持一致。这代表了使大型语言模型更具帮助性、诚实性和无害性的过程中的显著简化。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造