As established, supervised fine-tuning alone often struggles to capture the full spectrum of desired behaviors for large language models. This is where Reinforcement Learning (RL) steps in, providing a framework to optimize the LLM based on a learned signal representing human preferences. Given the prerequisites for this course, you likely have a background in RL. This section serves as a focused refresher on the principles most pertinent to RLHF, particularly those underlying Proximal Policy Optimization (PPO).
Framing LLM Generation as an RL Problem
At its core, RL deals with agents learning to make sequences of decisions in an environment to maximize a cumulative reward. We can frame the task of generating text with an LLM using the standard RL formalism, specifically the Markov Decision Process (MDP); a short code sketch after this list makes the mapping concrete:
- State (s): The current context, typically comprising the initial prompt and the sequence of tokens generated so far.
- Action (a): The next token to be generated by the LLM. The action space is the vocabulary of the language model.
- Policy (πθ(a∣s)): The LLM itself, parameterized by θ. It defines a probability distribution over the next possible tokens (actions) given the current state (prompt + previous tokens).
- Reward (R(s,a)): This is the critical component learned from human feedback. After a sequence (response) is generated, a reward signal reflects how well that sequence aligns with human preferences. This is often assigned only at the end of the sequence, based on the output of a Reward Model (RM), which we will detail in Chapter 3.
- Transition (P(s′∣s,a)): In text generation, the transition is typically deterministic. Given the current state s (e.g., "Translate 'hello' to French:") and the chosen action a (e.g., the token "Bonjour"), the next state s′ is simply the concatenation (e.g., "Translate 'hello' to French: Bonjour").
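To make this mapping concrete, below is a minimal sketch of a single generation episode viewed as an MDP rollout. The `policy`, `reward_model`, and `eos_id` names are illustrative placeholders rather than any specific library's API; assume `policy(ids)` returns next-token logits over the vocabulary and `reward_model(ids)` returns a scalar score for the full sequence.

```python
import torch

# Sketch only: `policy` and `reward_model` are hypothetical callables.
# policy(ids) -> logits of shape [batch, seq_len, vocab_size]
# reward_model(ids) -> scalar score for the whole sequence

def rollout(policy, reward_model, prompt_ids, max_new_tokens=64, eos_id=2):
    state = list(prompt_ids)          # s_0: the prompt token ids
    log_probs = []                    # log pi_theta(a_t | s_t), kept for the RL update

    for _ in range(max_new_tokens):
        logits = policy(torch.tensor([state]))[0, -1]      # distribution over the vocabulary (the action space)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()                             # a_t: the next token
        log_probs.append(dist.log_prob(action))
        state = state + [int(action)]                      # deterministic transition: s_{t+1} = s_t concatenated with a_t
        if int(action) == eos_id:
            break

    reward = float(reward_model(torch.tensor([state])))    # reward assigned only at the end of the sequence
    return state, log_probs, reward
```

In a full RLHF loop, many such trajectories are collected into a batch before any policy update is made.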
The objective in this RL setting is to adjust the LLM's parameters θ to find a policy πθ that maximizes the expected cumulative reward (often called the expected return) for the generated sequences:
$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[\sum_{t=0}^{T} \gamma^{t} R(s_t, a_t)\right]$$
Here, $\tau = (s_0, a_0, s_1, a_1, \ldots)$ represents a full trajectory (e.g., a complete generated response), $\pi_\theta$ dictates the probability of generating that trajectory, $T$ is the length of the sequence, and $\gamma$ is a discount factor (often set to 1 for finite-horizon text generation).
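As a small worked example, consider the common RLHF case where the only non-zero reward is the reward model's score for the finished response. The helper below is a generic discounted-return computation, not tied to any particular library:

```python
# Generic discounted-return computation for one trajectory.
# rewards[t] = R(s_t, a_t); returns[t] = sum over k >= t of gamma^(k - t) * rewards[k]

def discounted_returns(rewards, gamma=1.0):
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

# A 4-token response scored 0.8 by the reward model, with the reward on the final token only:
per_token_rewards = [0.0, 0.0, 0.0, 0.8]
print(discounted_returns(per_token_rewards, gamma=1.0))   # [0.8, 0.8, 0.8, 0.8]
print(discounted_returns(per_token_rewards, gamma=0.95))  # [0.6859, 0.722, 0.76, 0.8]
```

With γ = 1, every token in the response shares the same return, which is why a single terminal reward can be propagated back over the whole generated sequence.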
Policy Gradients and the Need for Stability
Policy Gradient methods directly optimize the policy parameters θ by performing gradient ascent on the objective J(θ). A common form of the policy gradient theorem gives us:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\right]$$
where $\hat{A}_t$ is an estimate of the advantage function at timestep $t$. The advantage $A(s, a) = Q(s, a) - V(s)$ represents how much better taking action $a$ in state $s$ is compared to the average action according to the current policy, as estimated by the value function $V(s)$. Intuitively, this increases the probability of actions that lead to higher-than-expected rewards and decreases the probability of actions leading to lower-than-expected rewards.
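In code, this update is usually implemented as a surrogate loss whose gradient matches the expression above. A minimal PyTorch-style sketch, assuming `log_probs` (the log-probabilities of the sampled tokens under $\pi_\theta$) and `advantages` (the $\hat{A}_t$ estimates) have already been gathered elsewhere:

```python
import torch

def policy_gradient_loss(log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    # Gradient ascent on E[log pi_theta(a_t | s_t) * A_hat_t] is implemented as
    # gradient descent on its negative. The advantages act as fixed weights,
    # so they are detached from the computation graph.
    return -(log_probs * advantages.detach()).mean()
```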
While effective, basic policy gradient methods can suffer from high variance and instability, especially when dealing with high-dimensional parameter spaces like those found in LLMs. A single large gradient update, driven by a particularly high or low reward sample, could drastically alter the policy, potentially leading to a collapse in performance. This instability motivated the development of more sophisticated algorithms like PPO.
Proximal Policy Optimization (PPO)
PPO has become the standard RL algorithm for fine-tuning LLMs in the RLHF pipeline due to its relative simplicity, stability, and strong empirical performance. It addresses the instability issue of vanilla policy gradients by constraining how much the policy can change in each update step.
PPO optimizes a surrogate objective function that incorporates a mechanism to discourage large policy updates. The most common variant uses a clipped objective:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]$$
Let's break this down:
- $\mathbb{E}_t[\ldots]$ indicates taking the average over a batch of collected experience (timesteps).
- $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio between the current policy $\pi_\theta$ (being optimized) and the old policy $\pi_{\theta_{\text{old}}}$ (used to collect the data). It measures how much the policy has changed.
- $\hat{A}_t$ is the estimated advantage for timestep $t$, often calculated using Generalized Advantage Estimation (GAE), which helps reduce variance. This requires learning a value function $V(s)$ (the "critic") alongside the policy (the "actor").
- $\operatorname{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)$ restricts the ratio $r_t(\theta)$ to stay within the interval $[1-\epsilon, 1+\epsilon]$. The hyperparameter $\epsilon$ (e.g., 0.2) defines the clipping range.
- $\min(\ldots)$ takes the minimum of two terms:
  - the unclipped objective $r_t(\theta)\,\hat{A}_t$;
  - the clipped objective $\operatorname{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\,\hat{A}_t$.
The effect of this clipping is to limit the influence of actions where the policy ratio $r_t(\theta)$ moves outside the $[1-\epsilon, 1+\epsilon]$ boundary. If the advantage $\hat{A}_t$ is positive (meaning the action was better than average), the objective is clipped from above if the policy change is too large ($r_t(\theta) > 1+\epsilon$), preventing overly aggressive updates. If the advantage $\hat{A}_t$ is negative (the action was worse than average), the objective is clipped from below if the policy change is too large ($r_t(\theta) < 1-\epsilon$), preventing overly drastic reductions in probability for that action. This keeps the new policy close to the old policy, ensuring more stable learning.
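A compact sketch of these two ingredients, the GAE advantage estimates mentioned above and the clipped surrogate loss, is shown below. The tensor names and shapes are illustrative assumptions (1-D tensors over the timesteps of a single trajectory), not a reference implementation:

```python
import torch

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation for one finished trajectory (reduces variance)."""
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0  # V = 0 beyond the terminal state
        delta = rewards[t] + gamma * next_value - values[t]          # TD residual delta_t
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate objective L^CLIP, negated so it can be minimized."""
    advantages = advantages.detach()                            # treat A_hat_t as a constant
    ratio = torch.exp(new_log_probs - old_log_probs.detach())   # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```

In practice the critic is trained jointly, typically with a squared-error loss between `values` and the empirical returns, and both terms are combined into a single optimizer step.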
The KL Divergence Constraint in RLHF
While PPO's clipped objective inherently encourages smaller updates, RLHF implementations often add an explicit penalty term based on the Kullback-Leibler (KL) divergence between the current policy $\pi_\theta$ and a reference policy $\pi_{\text{ref}}$. Typically, $\pi_{\text{ref}}$ is the initial SFT model.
This KL penalty is usually incorporated directly into the reward signal used during PPO training:
$$R_{\text{total}}(s, a) = R_{\text{RM}}(s, a) - \beta\, D_{\text{KL}}\!\left(\pi_\theta(\cdot \mid s)\,\big\|\,\pi_{\text{ref}}(\cdot \mid s)\right)$$
Or, more commonly in practice, as a per-token reward (sketched in code after the list below):

$$R_{\text{token}}(s_t, a_t) = R_{\text{RM},t} - \beta \log \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{ref}}(a_t \mid s_t)}$$
Here:
- $R_{\text{RM}}$ is the reward obtained from the Reward Model.
- $\beta$ is a hyperparameter controlling the strength of the KL penalty.
- The KL term penalizes the policy $\pi_\theta$ for diverging too much from the reference policy $\pi_{\text{ref}}$ on a token-by-token basis.
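A minimal sketch of this per-token reward shaping, assuming the log-probabilities of the sampled tokens under both the current policy and the frozen reference model are already available (all names here are illustrative):

```python
import torch

def kl_penalized_rewards(policy_log_probs: torch.Tensor,
                         ref_log_probs: torch.Tensor,
                         rm_score: float,
                         beta: float = 0.1) -> torch.Tensor:
    # Per-token penalty: -beta * (log pi_theta(a_t | s_t) - log pi_ref(a_t | s_t)).
    # Detached because the shaped rewards are fixed targets, not differentiated through.
    rewards = -beta * (policy_log_probs - ref_log_probs).detach()
    # The reward model's scalar score is added only on the final token of the response.
    rewards[-1] = rewards[-1] + rm_score
    return rewards
```

These shaped per-token rewards are what feed into the advantage estimation and PPO update described in the previous section.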
Why is this KL penalty so significant in RLHF?
- Preserving Capabilities: It prevents the LLM, while optimizing for the human preference reward $R_{\text{RM}}$, from drifting too far away from the general language modeling capabilities and knowledge embedded in the original SFT model. Without it, the model might learn to generate repetitive or nonsensical text that happens to score highly on the RM ("reward hacking") but is otherwise low quality.
- Stabilization: Similar to the clipping in PPO, it acts as a regularization term, ensuring smoother and more stable training updates.
We will explore the practical implementation and tuning of this KL penalty within the PPO algorithm in Chapter 4.
Tying it Together for RLHF
This refresher highlights the core RL components adapted for RLHF:
- We frame text generation as an MDP where the LLM acts as the policy.
- The goal is to maximize a reward signal derived from human preferences (via the Reward Model).
- PPO is employed to update the LLM's policy parameters stably.
- A KL divergence penalty relative to the initial SFT model is incorporated to maintain language quality and prevent catastrophic forgetting or reward hacking.
Understanding these RL foundations, particularly the mechanics and rationale behind PPO and the KL constraint, is essential for effectively implementing and troubleshooting the RL fine-tuning stage of the RLHF pipeline, which we will detail in the upcoming chapters.