Once you have trained a reward model rθ(x,y) that captures human preferences (as discussed in "Reward Model Training"), the next step in the Reinforcement Learning from Human Feedback (RLHF) pipeline is to use this model to improve the language model's policy, πϕ(y∣x). The goal is to fine-tune the LLM so that it generates responses y (given a prompt x) that score higher according to the reward model, while still maintaining coherent and helpful language capabilities learned during pre-training and initial fine-tuning (often referred to as the reference policy, πref).
Directly optimizing the LLM using supervised learning on high-reward samples isn't ideal, as it doesn't explore the space of possible outputs efficiently or directly optimize the preference objective. Instead, we frame this as a reinforcement learning problem where the LLM policy πϕ is the agent, the context x is the state, the generated response y is the action (or sequence of actions), and the reward model rθ provides the reward signal.
The Reinforcement Learning Objective for RLHF
The objective is to find policy parameters ϕ that maximize the expected reward from the reward model, penalized by a Kullback-Leibler (KL) divergence term. The KL penalty prevents the policy πϕ from moving too far away from the initial reference policy πref, which helps maintain language quality and prevents the model from finding degenerate solutions that exploit the reward model.
The optimization objective can be formulated as:
maximize_ϕ J(ϕ) = E_{x∼D, y∼πϕ(y∣x)} [ rθ(x,y) − β · DKL(πϕ(y∣x) ∣∣ πref(y∣x)) ]
Let's break down this objective:
- x∼D: Prompts x are sampled from a distribution D (often the same distribution used for collecting preference data).
- y∼πϕ(y∣x): Responses y are generated by the current LLM policy πϕ that we are optimizing.
- rθ(x,y): The scalar reward assigned to the prompt-response pair (x,y) by the trained reward model. This is the primary signal indicating alignment with human preferences.
- DKL(πϕ(y∣x)∣∣πref(y∣x)): The KL divergence between the distribution of responses generated by the current policy πϕ and the reference policy πref. This term measures how much the current policy has diverged from the original one. The reference policy is typically the model before RLHF fine-tuning, such as the supervised fine-tuned (SFT) model.
- β: A hyperparameter that controls the strength of the KL penalty. A higher β discourages deviation from πref, while a lower β allows more aggressive optimization toward the reward model's preferences.
This objective encourages the model to generate outputs that the reward model prefers, but not at the cost of generating text that is statistically very different from what the initial, well-behaved model πref would produce.
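To make the quantity inside the expectation concrete, here is a minimal PyTorch sketch of the per-sample KL-penalized reward, assuming you already have the reward model score and per-token log-probabilities under both the current and reference policies. The function and tensor names are illustrative, not from any particular library; the KL term is the usual single-sample estimate based on the log-probability difference over the generated response.

```python
import torch

def kl_penalized_objective(
    reward_scores: torch.Tensor,    # (batch,) scalar scores r_theta(x, y) from the reward model
    policy_logprobs: torch.Tensor,  # (batch, seq_len) per-token log pi_phi(y_t | x, y_<t)
    ref_logprobs: torch.Tensor,     # (batch, seq_len) per-token log pi_ref(y_t | x, y_<t)
    response_mask: torch.Tensor,    # (batch, seq_len) 1.0 on response tokens, 0.0 on padding
    beta: float = 0.1,
) -> torch.Tensor:
    """Per-sample estimate of r_theta(x, y) - beta * KL(pi_phi || pi_ref).

    The KL term is approximated by summing the per-token log-probability
    difference over the sampled response (a single-sample Monte Carlo estimate).
    """
    log_ratio = (policy_logprobs - ref_logprobs) * response_mask
    kl_estimate = log_ratio.sum(dim=-1)           # (batch,)
    return reward_scores - beta * kl_estimate     # (batch,)

# Toy usage with placeholder tensors, just to show the shapes involved.
batch, seq_len = 4, 16
rewards = torch.randn(batch)
policy_lp = -torch.rand(batch, seq_len)   # stand-ins for log-probabilities (<= 0)
ref_lp = -torch.rand(batch, seq_len)
mask = torch.ones(batch, seq_len)
print(kl_penalized_objective(rewards, policy_lp, ref_lp, mask, beta=0.05))
```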
Why Proximal Policy Optimization (PPO)?
While standard policy gradient methods like REINFORCE could theoretically be used, they often suffer from high gradient variance and instability, especially when dealing with the large parameter spaces and complex action spaces (sequences of tokens) of LLMs. Updates can be destructive: a single large gradient step can significantly harm the policy's performance.
Proximal Policy Optimization (PPO) has emerged as a popular and effective algorithm for RLHF. It belongs to the family of trust region optimization methods, which aim to improve stability by constraining the amount the policy changes in each update step. PPO achieves this stability without the heavy computational cost of some other trust region methods like TRPO.
The main advantages of PPO in this context include:
- Stability: PPO uses a clipped surrogate objective function that discourages large policy updates, leading to more stable training.
- Sample Efficiency: It allows for multiple epochs of gradient updates on the same batch of sampled data, improving data utilization compared to single-update methods.
- Simplicity: While conceptually sophisticated, its implementation is relatively straightforward compared to alternatives like TRPO.
PPO's Clipped Surrogate Objective
At the heart of PPO is its objective function. Let πϕold be the policy before the update (the policy used to generate the data for the current optimization step). PPO optimizes the policy πϕ based on the following objective:
L_CLIP(ϕ) = E_{(x,y)∼πϕold} [ min( r(ϕ) · A^(x,y), clip(r(ϕ), 1−ϵ, 1+ϵ) · A^(x,y) ) ]
Let's unpack the components relevant to RLHF:
- Probability Ratio: r(ϕ) = πϕ(y∣x) / πϕold(y∣x). This ratio measures how likely the response y is under the new policy πϕ compared to the old policy πϕold. In practice, we compute it from log probabilities: exp(log πϕ(y∣x) − log πϕold(y∣x)).
- Advantage Estimate: A^(x,y). This estimates how much better the response y is compared to the average response expected from prompt x. In the context of RLHF, a common approach is to use the KL-penalized reward as the advantage:
A^(x,y)≈rθ(x,y)−βDKL(πϕ(y∣x)∣∣πref(y∣x))
Sometimes a baseline value function V(x) (estimating the expected reward given prompt x) is subtracted from rθ(x,y) to reduce variance, but simpler implementations often use the reward model score directly (potentially normalized or combined with the KL term). Let's denote the combined reward signal used for optimization as R(x,y) = rθ(x,y) − β(log πϕ(y∣x) − log πref(y∣x)); the per-sample log-probability difference acts as a Monte Carlo estimate of the KL divergence above. So, A^(x,y) ≈ R(x,y).
- Clipping: The clip(r(ϕ),1−ϵ,1+ϵ) function constrains the probability ratio r(ϕ) to stay within the interval [1−ϵ,1+ϵ]. ϵ is a small hyperparameter (e.g., 0.2).
- If the advantage A^ is positive (meaning the response y was better than expected), the objective increases as r(ϕ) increases, encouraging the policy update. However, the min and the upper clipping boundary 1+ϵ prevent r(ϕ) from becoming too large, thus limiting the update step size.
- If the advantage A^ is negative (the response was worse than expected), the objective decreases as r(ϕ) increases. The lower clipping boundary 1−ϵ prevents r(ϕ) from becoming too small, again limiting how much the policy can change in a single step.
This clipping mechanism effectively creates a pessimistic bound on the policy update, ensuring that the new policy doesn't deviate too drastically from the old one in a way that could destabilize training.
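The clipped surrogate itself reduces to a few lines of tensor arithmetic. The sketch below treats the whole response as a single action and uses sequence-level log-probabilities; real implementations usually work per token, but the clipping logic is identical. The function and variable names are illustrative assumptions.

```python
import torch

def ppo_clipped_loss(
    new_logprobs: torch.Tensor,  # (batch,) log pi_phi(y | x) under the current policy
    old_logprobs: torch.Tensor,  # (batch,) log pi_phi_old(y | x), fixed for this batch
    advantages: torch.Tensor,    # (batch,) advantage estimates A^(x, y)
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """PPO clipped surrogate objective, negated so it can be minimized with gradient descent."""
    ratio = torch.exp(new_logprobs - old_logprobs)                          # r(phi)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Take the pessimistic (elementwise minimum) bound before averaging over the batch.
    return -torch.min(unclipped, clipped).mean()

# Toy usage: a positive-advantage sample whose ratio exceeds 1 + eps gets clipped.
new_lp = torch.tensor([-1.0, -2.0, -0.5])
old_lp = torch.tensor([-1.5, -1.8, -0.4])
adv = torch.tensor([1.0, -0.5, 0.3])
print(ppo_clipped_loss(new_lp, old_lp, adv))
```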
The PPO Loop for LLM Fine-Tuning
Applying PPO to fine-tune an LLM involves an iterative process:
Figure: The RLHF PPO optimization cycle. Prompts are sampled, responses are generated by the current policy, scored by the reward model, and evaluated under the reference policy to compute a KL-penalized reward. This batch of experience is then used to update the policy via the PPO objective.
- Sample Prompts: Draw a batch of prompts x from the dataset D.
- Generate Responses: For each prompt x, generate a response y using the current policy πϕ. Store the generated sequences (x,y) and their log-probabilities logπϕ(y∣x). This πϕ becomes πϕold for the optimization step.
- Score Responses: Use the fixed reward model rθ to compute the reward score rθ(x,y) for each generated pair.
- Compute Reference Log-Probabilities: Calculate the log-probabilities of the generated responses y under the fixed reference policy πref, yielding logπref(y∣x).
- Calculate Final Reward: Compute the reward signal used for optimization, typically incorporating the KL penalty: R(x,y)=rθ(x,y)−β(logπϕ(y∣x)−logπref(y∣x)). The KL term is estimated using the difference in log-probabilities between the current/old policy and the reference policy.
- Optimize Policy: Perform multiple epochs of gradient ascent on the PPO clipped surrogate objective LCLIP(ϕ) using the collected batch of experiences (x,y,R(x,y),logπϕold(y∣x)). An optimizer like Adam or AdamW is commonly used to update the parameters ϕ of the language model.
This cycle repeats, iteratively refining the policy πϕ to produce responses that achieve higher rewards according to rθ while staying close to πref.
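Putting the steps together, one iteration of the outer loop might look like the sketch below. The helper callables (generate, logprobs_fn, reward_fn) are hypothetical placeholders for the generation, log-probability, and reward-scoring machinery described above; this is a schematic of the data flow under those assumptions, not a drop-in implementation.

```python
import torch

def rlhf_ppo_step(
    policy,            # trainable LM (pi_phi)
    ref_model,         # frozen reference policy (pi_ref)
    prompts,           # batch of prompts sampled from D
    generate,          # hypothetical: generate(policy, prompts) -> responses
    logprobs_fn,       # hypothetical: logprobs_fn(model, prompts, responses) -> (batch,) tensor
    reward_fn,         # hypothetical: reward_fn(prompts, responses) -> (batch,) tensor from r_theta
    optimizer,         # e.g. torch.optim.AdamW over the policy's parameters
    beta: float = 0.1,
    clip_eps: float = 0.2,
    ppo_epochs: int = 4,
):
    # Steps 1-4: sample responses, freeze their log-probs as pi_phi_old, score them.
    with torch.no_grad():
        responses = generate(policy, prompts)
        old_logprobs = logprobs_fn(policy, prompts, responses)
        ref_logprobs = logprobs_fn(ref_model, prompts, responses)
        rewards = reward_fn(prompts, responses)
        # Step 5: KL-penalized reward R(x, y) used as the optimization signal.
        kl_penalized = rewards - beta * (old_logprobs - ref_logprobs)
        # Whiten into crude advantage estimates (a common variance-reduction trick).
        advantages = (kl_penalized - kl_penalized.mean()) / (kl_penalized.std() + 1e-8)

    # Step 6: several epochs of clipped-surrogate updates on the same batch of experience.
    for _ in range(ppo_epochs):
        new_logprobs = logprobs_fn(policy, prompts, responses)   # requires grad
        ratio = torch.exp(new_logprobs - old_logprobs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        loss = -torch.min(unclipped, clipped).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()
```

In practice, libraries such as Hugging Face TRL provide production-grade versions of this loop (with per-token advantages, a value head, and generalized advantage estimation), so a hand-rolled sketch like this is mainly useful for seeing how the pieces connect.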
Practical Considerations
- KL Coefficient (β): Choosing an appropriate value for β matters. If β is too small, the policy may over-optimize for the reward model, producing repetitive or nonsensical outputs that exploit reward model weaknesses, or drifting too far from the general language capabilities captured in πref. If β is too large, policy updates are overly constrained and the model may not improve enough on the desired alignment characteristics. β can be a fixed hyperparameter, or it can be adjusted dynamically during training based on the observed KL divergence (a sketch of one such controller follows this list).
- Model Copies: Efficient implementation requires managing several copies or states of the model: the policy being actively trained (πϕ), the policy used for generating the current batch of experience (πϕold), and the frozen reference policy (πref).
- Computational Resources: The PPO step is computationally demanding. It involves:
- Autoregressive generation with πϕ to produce responses (one forward pass per generated token).
- A forward pass through the (potentially large) reward model rθ.
- A forward pass through the reference model πref.
- Multiple forward and backward passes through πϕ during the PPO optimization epochs.
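As an illustration of the dynamic adjustment of β mentioned above, here is a minimal sketch of a proportional KL controller: β is nudged up when the observed KL exceeds a target and relaxed when it falls below. The exact update rule and constants are illustrative assumptions, not a prescribed recipe.

```python
class AdaptiveKLController:
    """Adjusts the KL coefficient beta toward a target KL divergence.

    If the measured KL between pi_phi and pi_ref drifts above the target,
    beta is increased to pull the policy back toward the reference; if it
    falls below, beta is relaxed so the policy can chase the reward harder.
    """

    def __init__(self, init_beta: float = 0.1, target_kl: float = 6.0, horizon: int = 10_000):
        self.beta = init_beta
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl: float, batch_size: int) -> float:
        # Proportional error, clipped so one noisy batch cannot swing beta too far.
        error = max(-0.2, min(0.2, observed_kl / self.target_kl - 1.0))
        self.beta *= 1.0 + error * batch_size / self.horizon
        return self.beta

# Example: an observed KL above the target nudges beta upward.
controller = AdaptiveKLController(init_beta=0.1, target_kl=6.0)
print(controller.update(observed_kl=9.0, batch_size=512))
```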
In summary, PPO provides a robust and widely adopted method for the policy optimization phase of RLHF. By leveraging the signal from a learned reward model and incorporating a KL divergence penalty within a stable optimization framework, PPO enables fine-tuning large language models to better align with human preferences while preserving their core language understanding and generation abilities.