Okay, we have established the goal: to refine our Supervised Fine-Tuned (SFT) model using human preferences captured by the reward model (RM). Simple supervised updates won't work directly with the scalar reward signal from the RM. Instead, we turn to reinforcement learning (RL) algorithms. Among the various RL techniques, Proximal Policy Optimization (PPO) has emerged as a popular and effective choice for fine-tuning large language models in the RLHF pipeline.
So, what is PPO, and why is it suitable here? PPO belongs to the family of policy gradient methods in RL. The fundamental idea of policy gradient methods is to directly adjust the parameters θ of the policy (in our case, the language model πθ) to maximize the expected rewards. We estimate the gradient of the expected reward and take steps in that direction. However, naive policy gradient implementations can be unstable. A single update step that changes the policy too drastically can lead to a collapse in performance, from which it might be difficult to recover. This is particularly risky with large, complex models like LLMs.
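For contrast, here is a bare-bones sketch of what such a naive policy-gradient (REINFORCE-style) update looks like in PyTorch; the names and shapes are illustrative. It simply scales each sampled token's log-probability by its trajectory's reward, with nothing keeping the updated policy close to the one that generated the data.

```python
import torch

def naive_pg_loss(logprobs, rewards):
    """REINFORCE-style surrogate loss: -E[log pi_theta(a_t|s_t) * R].

    logprobs: log-probabilities of the sampled tokens under the current policy.
    rewards:  reward of each token's trajectory, broadcast per token.
    Note: nothing here constrains how far a single gradient step moves the policy.
    """
    return -(logprobs * rewards).mean()
```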
PPO addresses this stability issue by constraining how much the policy can change in each update step. It achieves this through a specific objective function that discourages large deviations from the previous policy while still encouraging improvements based on the reward signal.
The PPO Objective Function
At its core, PPO optimizes a surrogate objective function. The most common variant uses a clipped objective. Let's define some terms first:
- Policy πθ(at∣st): Our language model, parameterized by θ. Given a state st (the input prompt and the tokens generated so far), it outputs a probability distribution over possible next tokens, from which the action at is sampled.
- Old Policy πθold(at∣st): The policy before the current update iteration. We sample trajectories (sequences of tokens) using this policy.
- Advantage Function A^t: An estimate of how much better taking action at in state st is compared to the average action under the current policy. It is typically computed from the RM reward and a learned value function V(st) (the critic), often using Generalized Advantage Estimation (GAE); a sketch of GAE follows this list. A positive advantage suggests the action taken was better than expected, while a negative advantage suggests it was worse.
- Probability Ratio rt(θ): This measures the change in probability of taking action at in state st between the new policy πθ and the old policy πθold.
rt(θ) = πθ(at∣st) / πθold(at∣st)
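Before turning to the objective itself, here is a minimal sketch of how the advantages and return targets might be computed with GAE. It assumes per-token rewards (typically zero everywhere except the final token, which receives the RM score) and critic values are already available as tensors; the function name and hyperparameter defaults are illustrative, not prescribed.

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over one generated sequence.

    rewards: per-token rewards (often zero except the final token, which
             receives the reward-model score).
    values:  critic estimates V(s_t) for each position, with one extra
             bootstrap value appended (0 for a finished sequence).
    """
    advantages = torch.zeros_like(rewards)
    last_adv = 0.0
    # Walk backwards, accumulating discounted TD residuals.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    returns = advantages + values[:-1]  # regression targets for the critic
    return advantages, returns
```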
The clipped surrogate objective function, often denoted as LCLIP(θ), is then formulated as:
LCLIP(θ)=E^t[min(rt(θ)A^t,clip(rt(θ),1−ϵ,1+ϵ)A^t)]
Let's break this down:
- E^t[...]: This indicates we are taking the average over a batch of timesteps collected from interactions with the environment (i.e., generating text sequences and getting rewards from the RM).
- rt(θ)A^t: This is the standard policy gradient objective. If the advantage A^t is positive, we want to increase the probability of taking action at, so we increase rt(θ). If A^t is negative, we want to decrease the probability, so we decrease rt(θ).
- clip(rt(θ),1−ϵ,1+ϵ): This function clamps the probability ratio rt(θ) to stay within the range [1−ϵ,1+ϵ]. The hyperparameter ϵ (epsilon) is typically a small value like 0.1 or 0.2. It defines the trust region around the old policy.
- clip(rt(θ),1−ϵ,1+ϵ)A^t: This is the clipped version of the objective term.
- min(...): The minimum operator is the key ingredient: it always selects the more pessimistic of the unclipped and clipped terms.
- If A^t>0 (the action was good): The objective becomes min(rt(θ)A^t,(1+ϵ)A^t). Once rt(θ) grows beyond 1+ϵ, the objective is capped at (1+ϵ)A^t, so there is no further gain from increasing the probability of a good action in one step.
- If A^t<0 (the action was bad): The objective becomes min(rt(θ)A^t,(1−ϵ)A^t), which equals A^t·max(rt(θ),1−ϵ). Once rt(θ) drops below 1−ϵ, the objective is capped at (1−ϵ)A^t, so there is no further gain from decreasing the probability of a bad action in one step; making the bad action more likely, however, remains fully penalized.
Essentially, the clipping removes the incentive for the policy to change too drastically, preventing large, potentially destabilizing updates. The policy is encouraged to improve (rt(θ) moves in the direction favored by A^t), but only within the bounds set by ϵ.
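To make the clipping concrete, here is a minimal PyTorch sketch of the clipped surrogate loss (returned negated, since optimizers minimize). The tensor names are illustrative; the log-probabilities and advantages are assumed to come from the rollout and GAE steps above.

```python
import torch

def ppo_clip_loss(logprobs_new, logprobs_old, advantages, clip_eps=0.2):
    """Clipped surrogate objective, returned as a loss to minimize.

    logprobs_new: log pi_theta(a_t|s_t) for the sampled tokens (requires grad).
    logprobs_old: log pi_theta_old(a_t|s_t) from the rollout phase, detached.
    advantages:   advantage estimates A_hat_t, detached.
    """
    ratio = torch.exp(logprobs_new - logprobs_old)              # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Negate because optimizers minimize; the objective itself is maximized.
    return -torch.min(unclipped, clipped).mean()
```

For instance, with ϵ=0.2 and A^t=+1, a ratio of 1.5 contributes min(1.5, 1.2)=1.2 to the objective, so pushing the ratio higher yields no additional reward.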
Actor-Critic Implementation
PPO is often implemented using an Actor-Critic architecture.
- Actor: The policy network πθ (our LLM) which decides which action (token) to take.
- Critic: The value function network Vϕ(st) (often sharing lower layers with the actor) which estimates the expected return (cumulative future reward) from a given state st. This value estimate is used to compute the advantage A^t.
The overall optimization involves maximizing the LCLIP objective (for the actor) while simultaneously minimizing a loss function for the value function (critic), usually the mean squared error between the predicted value Vϕ(st) and the actual observed returns. An optional entropy bonus term might also be added to the objective to encourage exploration.
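As a rough sketch of how these pieces are typically combined into a single loss, the coefficient names and default values below are common conventions rather than anything prescribed by PPO itself.

```python
import torch
import torch.nn.functional as F

def ppo_total_loss(logprobs_new, logprobs_old, advantages,
                   values_pred, returns, entropy,
                   clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """Combined loss: clipped policy term + value-function MSE - entropy bonus."""
    ratio = torch.exp(logprobs_new - logprobs_old)               # r_t(theta)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages,
    ).mean()
    value_loss = F.mse_loss(values_pred, returns)                # critic regression
    entropy_bonus = entropy.mean()                               # encourages exploration
    return policy_loss + vf_coef * value_loss - ent_coef * entropy_bonus
```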
Figure: Simplified overview of the PPO Actor-Critic loop in RLHF. The Actor (LLM) generates text, the Reward Model provides a reward, and the Critic estimates the state value. PPO uses these components to calculate the advantage and update both the Actor and the Critic via the clipped objective.
Why PPO for LLM Alignment?
Compared to other RL algorithms, PPO strikes a good balance:
- Stability: The clipped objective function provides more stable training updates than simpler policy gradient methods, which is essential for large models where training is expensive and prone to divergence.
- Sample Efficiency: While perhaps not as sample-efficient as some off-policy methods in simpler domains, PPO is generally more efficient than basic policy gradient methods like REINFORCE: each batch of collected data is reused for multiple epochs of updates before fresh rollouts are generated.
- Implementation Complexity: PPO is less complex to implement and tune compared to some alternatives like Trust Region Policy Optimization (TRPO), which involves second-order optimization.
In the context of RLHF, PPO allows us to effectively use the scalar reward signal from the RM to guide the LLM towards generating outputs that align better with human preferences, while mitigating the risk of destabilizing the model during fine-tuning. The next section will detail the specifics of implementing this RL fine-tuning phase using PPO.