Now that we have a reward model (rϕ) capable of scoring language model outputs based on human preferences, we need a mechanism to optimize the language model itself to generate outputs that achieve high scores according to this reward model. This is where Reinforcement Learning (RL) comes into play, and specifically, Proximal Policy Optimization (PPO) has emerged as a standard algorithm for this phase in RLHF.
PPO is a policy gradient algorithm designed to take the biggest possible improvement step on a policy without stepping too far and causing performance collapse. It achieves this by constraining the policy update within a "trust region," preventing drastic changes between policy iterations. This makes it more stable and sample-efficient compared to simpler policy gradient methods, which is particularly important when dealing with large, complex models like LLMs.
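To make the clipping mechanism concrete, here is a minimal PyTorch sketch of the clipped surrogate loss that PPO minimizes; the function name, tensor names, and default clip range are illustrative, and the RLHF-specific details are developed throughout this section and the ones that follow:

```python
import torch

def clipped_surrogate_loss(new_logprobs, old_logprobs, advantages, clip_range=0.2):
    """PPO's clipped surrogate objective, negated so it can be minimized."""
    # Probability ratio between the updated policy and the one that collected the data.
    ratio = torch.exp(new_logprobs - old_logprobs)
    # Clipping the ratio to [1 - eps, 1 + eps] caps how far a single update can
    # push the policy, acting as a simple stand-in for a trust-region constraint.
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Take the pessimistic (minimum) of the clipped and unclipped objectives, then average.
    return -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
```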
Why PPO for RLHF?
While various RL algorithms exist, PPO offers several advantages that make it well-suited for fine-tuning large language models:
- Stability: Its core mechanism, either through KL divergence penalization or objective clipping, limits how much the policy can change in each update step. This prevents the model from diverging wildly or "forgetting" its language capabilities learned during pre-training and SFT.
- Sample Efficiency: Compared to algorithms requiring many off-policy corrections or complex exploration strategies, PPO strikes a reasonable balance between sample efficiency and ease of implementation. While still on-policy (requiring fresh samples from the current policy), its update mechanism allows for multiple epochs of optimization on the collected data batch.
- Implementation Simplicity: Relative to some other advanced RL algorithms, PPO's update rule is more straightforward to implement and debug, especially with libraries like Hugging Face's TRL providing optimized components.
- Compatibility with Large Models: PPO readily integrates with the typical actor-critic architecture, which maps well onto the LLM fine-tuning paradigm.
The RLHF Objective with PPO
The core task in the RL phase is to adjust the parameters (θ) of our language model policy (πθ) to maximize the expected reward given by the reward model, while simultaneously ensuring the policy doesn't deviate excessively from the original supervised fine-tuned (SFT) policy (πSFT). This constraint discourages the model from drifting toward outputs that score highly by exploiting the reward model (reward hacking) yet are nonsensical or stylistically inconsistent.
This dual objective is captured in the PPO optimization process for RLHF. For a given prompt x sampled from a distribution D, and a response y generated by the current policy πθ(y∣x), the objective per prompt-response pair incorporates the reward and a penalty based on the Kullback-Leibler (KL) divergence between the current policy and the reference SFT policy:
$$\text{Objective}(x, y) = r_\phi(x, y) - \beta \, D_{\mathrm{KL}}\!\big(\pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{SFT}}(y \mid x)\big)$$
Let's break this down:
- rϕ(x,y): This is the scalar reward assigned by our trained reward model to the prompt-response pair (x,y). The goal is to maximize this value.
- DKL(πθ(y∣x)∣∣πSFT(y∣x)): This is the KL divergence between the probability distribution of the response y given prompt x under the current policy πθ and the reference SFT policy πSFT. It measures how much the current policy has diverged from the SFT model for this specific generation. Lower KL divergence means the policies are more similar.
- β: This is a hyperparameter that controls the strength of the KL penalty.
- A high β heavily penalizes deviation from πSFT, keeping the model stylistically close to the SFT version but potentially limiting reward maximization.
- A low β allows the policy to optimize more aggressively for the reward rϕ, but risks deviating significantly from the SFT model, potentially degrading text quality or exploiting the reward model.
The overall PPO objective aims to maximize the expected value of this combined term over trajectories sampled using the current policy πθ. The PPO algorithm then uses its specific mechanisms (like the clipped surrogate objective, which we'll discuss later) to optimize this objective effectively.
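As a concrete, simplified illustration, the per-response objective can be turned into per-token rewards for PPO. The sketch below is an assumption-laden example: the function name, the β value, and the convention of spreading the KL penalty over tokens while adding the scalar reward at the final token are not prescribed by the objective itself, though they mirror common RLHF implementations.

```python
import torch

def kl_penalized_rewards(reward_score, policy_logprobs, ref_logprobs, beta=0.1):
    """Sketch of r_phi(x, y) - beta * KL, spread over the tokens of one response.

    reward_score:    scalar r_phi(x, y) from the reward model for this response
    policy_logprobs: log pi_theta(y_t | x, y_<t) for each generated token
    ref_logprobs:    log pi_SFT(y_t | x, y_<t) for the same tokens
    beta:            KL penalty coefficient
    """
    # Per-token estimate of the KL divergence on the sampled tokens.
    kl = policy_logprobs - ref_logprobs
    # Penalize each token by its KL contribution...
    per_token_rewards = -beta * kl
    # ...and add the scalar reward-model score at the final token, a common
    # convention in RLHF implementations.
    per_token_rewards[-1] += reward_score
    return per_token_rewards
```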
Adapting PPO Components for LLMs
In a standard PPO setup, you typically have an actor (the policy) and a critic (the value function). In the RLHF context with LLMs, these map onto four models, described below and followed by a minimal setup sketch in code:
- Policy Model (Actor): This is the language model we are fine-tuning (πθ). It takes a prompt x as input (state) and generates a sequence of tokens y (action). We usually start with the SFT model's weights and update them during PPO training.
- Reference Model: This is a frozen copy of the initial SFT model (πSFT). Its role is crucial for calculating the KL divergence penalty. Keeping it fixed provides a stable target distribution to regularize against.
- Reward Model: This is the model trained in the previous stage (rϕ). It takes a prompt x and a generated response y and outputs a scalar reward. It is typically kept frozen during the PPO phase.
- Value Model (Critic): This model (Vψ) takes the prompt x together with the tokens generated so far (the state) and estimates the expected future reward from that point onward under the current policy πθ, producing one value estimate per token position. These estimates are used to compute advantages, which reduces the variance of the policy gradient. Often, the value model is initialized from the reward model's weights or the SFT model's weights, potentially with a different output head. Its parameters (ψ) are updated alongside the policy parameters θ during PPO training.
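A minimal sketch of how these four components might be instantiated with Hugging Face Transformers is shown below. The checkpoint paths are placeholders, and the assumptions that the reward model was saved with a scalar sequence-classification head and that the value model is built from the SFT backbone are illustrative choices, not the only options.

```python
import copy
import torch
from transformers import (AutoModel, AutoModelForCausalLM,
                          AutoModelForSequenceClassification, AutoTokenizer)

SFT_CHECKPOINT = "path/to/sft-model"      # placeholder paths
RM_CHECKPOINT = "path/to/reward-model"

tokenizer = AutoTokenizer.from_pretrained(SFT_CHECKPOINT)

# Policy (actor): starts from the SFT weights and is updated during PPO.
policy = AutoModelForCausalLM.from_pretrained(SFT_CHECKPOINT)

# Reference model: a frozen copy of the SFT model, used only for the KL penalty.
ref_model = copy.deepcopy(policy).eval()
for p in ref_model.parameters():
    p.requires_grad_(False)

# Reward model: frozen scalar-output model from the previous stage
# (assumed here to use a sequence-classification head with a single label).
reward_model = AutoModelForSequenceClassification.from_pretrained(
    RM_CHECKPOINT, num_labels=1
).eval()
for p in reward_model.parameters():
    p.requires_grad_(False)

# Value model (critic): a transformer backbone with a scalar head that emits
# one value per token position; initialized here from the SFT backbone,
# though initializing from the reward model is also common.
class ValueModel(torch.nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.value_head = torch.nn.Linear(backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        hidden = self.backbone(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.value_head(hidden).squeeze(-1)

value_model = ValueModel(AutoModel.from_pretrained(SFT_CHECKPOINT))
```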
The diagram below illustrates the high-level interaction during a PPO step in RLHF:
[Figure: Interaction during a PPO step in RLHF. Prompts are fed to the policy model to generate responses; the reward, value, and reference models evaluate those responses to produce rewards, value estimates, and KL penalties, which feed advantage estimation and the PPO updates of the policy and value models.]
The PPO Training Loop in RLHF
At a high level, the PPO fine-tuning process iterates over the following steps (a schematic sketch of the optimization phase follows the list):
- Rollout: Sample a batch of prompts (x) from a dataset (often the same used for reward modeling or SFT). For each prompt, generate a response (y) using the current policy πθ. Store the prompt, response, and the token log-probabilities from πθ.
- Evaluation: For each generated prompt-response pair (x,y):
- Calculate the reward r=rϕ(x,y) using the frozen reward model.
- Calculate the KL divergence penalty term. This usually involves getting the log-probabilities of the generated response y under the frozen reference model πSFT and comparing them to the log-probabilities from πθ obtained during the rollout. The reward used for PPO updates is often adjusted to be r−β×KL.
- Estimate per-token values with the current value model Vψ (one estimate per position in the prompt-response sequence).
- Advantage Estimation: Compute the advantage estimates for each token generated in the response. This typically uses Generalized Advantage Estimation (GAE), which combines the rewards and value estimates to provide a less noisy signal for policy updates. We will cover GAE in detail in the section "Calculating Advantages and Returns".
- Optimization: Perform multiple epochs of gradient updates on the collected batch of rollout data:
- Update the parameters θ of the policy model πθ using the PPO clipped surrogate objective function, which uses the calculated advantages and the ratio of token probabilities between the current policy and the policy used for the rollout.
- Update the parameters ψ of the value model Vψ by minimizing the difference between its predictions and the actual observed returns (calculated during advantage estimation).
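The optimization phase above might look like the following schematic sketch. It assumes a hypothetical `batch` dictionary of per-token tensors collected during rollout and evaluation (input ids, attention mask, rollout log-probabilities, GAE-based advantages, and returns) and models shaped like the earlier setup sketch; real implementations also shuffle minibatches, mask prompt and padding tokens, and clip gradients.

```python
import torch
import torch.nn.functional as F

def run_optimization_phase(policy, value_model, policy_opt, value_opt,
                           batch, num_epochs=4, clip_range=0.2):
    """Schematic PPO optimization phase over one collected rollout batch."""
    for _ in range(num_epochs):
        # Re-evaluate the rollout tokens under the *current* policy to form the
        # probability ratio against the log-probs recorded at rollout time.
        logits = policy(batch["input_ids"],
                        attention_mask=batch["attention_mask"]).logits
        logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
        new_logprobs = logprobs.gather(
            -1, batch["input_ids"][:, 1:].unsqueeze(-1)
        ).squeeze(-1)

        # Clipped surrogate objective for the policy update.
        ratio = torch.exp(new_logprobs - batch["rollout_logprobs"])
        clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
        policy_loss = -torch.min(ratio * batch["advantages"],
                                 clipped * batch["advantages"]).mean()
        policy_opt.zero_grad()
        policy_loss.backward()
        policy_opt.step()

        # Value update: regress the critic's per-token predictions toward the
        # observed returns (token-alignment conventions vary between codebases).
        values = value_model(batch["input_ids"],
                             attention_mask=batch["attention_mask"])
        value_loss = F.mse_loss(values[:, :-1], batch["returns"])
        value_opt.zero_grad()
        value_loss.backward()
        value_opt.step()
```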
This cycle repeats, gradually improving the policy πθ to generate responses that better align with the preferences captured by the reward model, while the KL penalty and PPO's clipping mechanism keep the training stable and prevent catastrophic forgetting.
Understanding this adaptation of PPO is fundamental. The subsequent sections will go deeper into implementing the policy and value networks, the specifics of the KL divergence penalty, how advantages are computed, and the practicalities of tuning and troubleshooting the PPO phase for LLM alignment.