整合监督微调(SFT)、奖励模型构建(RM)和强化学习(RL)阶段需要仔细管理操作顺序以及数据和模型在它们之间的流动。与其将这看作三个独立的过程,不如将其视为一个多阶段计算图,其中一个阶段的输出成为下一个阶段的必要输入。
整个RLHF过程通常由于固有的依赖关系而遵循特定顺序:
- 监督微调(SFT): 这是获得基础预训练语言模型后的起点。它使用高质量的演示数据使模型适应所需适用范围或风格。主要输出是SFT模型,通常表示为 πSFT。
- 奖励模型(RM)训练: 此阶段需要偏好数据。这些数据通常通过获取提示,使用 πSFT 模型(或有时是基础模型,或两者混合)生成多个回复,并让人类标注哪个回复更受青睐来生成。输入是提示、生成的回复对(yw,yl 分别代表获胜和失败),以及人类偏好标签。输出是训练好的奖励模型,Rϕ。
- RL微调(PPO): 此阶段使用 πSFT 模型作为要优化的初始策略(π0)。它还需要训练好的奖励模型 Rϕ 来在训练期间提供奖励信号。提示被送入当前策略 πk,其生成回复。这些回复由 Rϕ 评分。PPO算法随后更新策略 πk 以最大化预期奖励,同时一个KL散度项防止更新后的策略 πk+1 偏离参考策略过远(通常 π0=πSFT)。最终输出是对齐后的策略模型,πRLHF。
这种顺序依赖决定了工作流结构。数据产物和模型检查点必须在这些阶段之间正确传递。
数据和模型流
考虑所产生和消耗的产物:
- 基础LLM: SFT的输入。
- SFT演示数据: SFT的输入。
- SFT模型 (πSFT): SFT的输出;奖励模型数据生成的输入;PPO的输入(作为初始策略 π0)。
- 偏好数据的提示: 用于生成奖励模型训练对的输入。
- 偏好数据 (prompt,yw,yl): 奖励模型训练的输入。
- 奖励模型 (Rϕ): 奖励模型训练的输出;PPO的输入(作为奖励函数)。
- RL训练的提示: PPO用于生成轨迹的输入。
- RLHF策略模型 (πRLHF): PPO阶段的最终输出。
管理这些转换非常重要。您需要可靠地保存SFT模型在训练完成后,然后加载它用于生成偏好数据样本(如果需要)和初始化PPO策略的机制。同样,训练好的奖励模型检查点需要保存,然后由PPO训练器加载。
以下是说明依赖关系的图表:
三阶段RLHF管道中的依赖关系和数据流。圆柱体表示模型,笔记形状表示数据集,圆角矩形表示过程。
实现编排
您如何实现这种编排取决于项目的规模和复杂程度:
- 手动执行: 对于初期实验或小型项目,您可能只需为每个阶段运行各自的脚本。您将从一个脚本保存输出模型(例如
sft_model.pt、reward_model.bin),并手动将其指定为下一个阶段脚本中的输入路径。这直接了当,但容易出错且不易重现。
- Shell/Python脚本编写: 常见方法是编写包装脚本(例如使用Bash或Python),按顺序执行每个阶段的命令。这些脚本可以处理阶段之间文件路径或参数的传递,管理目录,并执行基本错误检查。相比纯手动执行,这提高了重现性。
- 工作流编排平台: 对于大规模、生产级别的RLHF训练,专门的工作流引擎变得非常有价值。Kubeflow Pipelines、Apache Airflow、Metaflow或Prefect等工具允许您将整个管道定义为有向无环图(DAG)。
- 优势: 这些平台自动管理依赖关系,处理失败重试,促进独立步骤(如有)的并行执行,提供日志记录和监控,并常与云环境集成以管理计算资源。它们显著提升重现性和操作可靠性。
- 结构: 您通常将每个阶段(SFT、奖励模型训练、PPO)定义为平台框架内的组件或任务。平台随后按正确顺序执行这些任务,管理数据产物(如存储在云存储中的模型检查点)在它们之间的传递。
无论采用哪种方法,在每个主要阶段结束时设置检查点都非常重要。这允许您在某个阶段失败时恢复管道,或者如果您想使用相同的初始SFT或RM结果来试验后续阶段。对模型和数据集与代码一同进行版本控制(使用Git LFS、DVC或MLflow等工具)也对追踪实验和确保整个管道中使用一致的组件有很大帮助。不同阶段可能也有不同的计算要求(例如,与SFT或RM训练相比,PPO通常需要更多的GPU资源,特别是用于actor/critic/RM/参考模型的多个GPU),编排平台可以通过为每个任务分配适当的资源来帮助管理这些需求。