To effectively optimize the policy using the PPO objective, we need a reliable estimate of how much better a chosen action (generating a specific token) is compared to the average action the policy would take in that state (the current generated sequence). This is the role of the advantage function, $A(s_t, a_t)$. Directly using the immediate reward $r_t$ (which includes the reward model signal and the KL penalty) is insufficient because it ignores the long-term consequences of an action. We need to consider the total accumulated reward, or return, and compare it against a baseline provided by a value function $V(s_t)$, which estimates the expected return from state $s_t$.
The return $G_t$ is the total discounted reward received starting from time step $t$ until the end of the episode (the generated sequence). For a sequence of length $T$, it's defined as:
$$G_t = \sum_{k=0}^{T-t-1} \gamma^k \, r_{t+k+1}$$

Here, $r_{t+k+1}$ is the reward received after taking an action in state $s_{t+k}$. In the RLHF context for LLMs, the reward $r_t$ at each step typically consists of two components: a penalty based on the KL divergence between the current policy and the reference (SFT) policy, and potentially a contribution from the final reward model score $R(x)$ assigned to the complete sequence $x$. A common practice is to apply the KL penalty $r_{\text{KL},t} = -\beta \, D_{\text{KL}}\!\left(\pi_\theta(\cdot \mid s_t) \,\|\, \pi_{\text{ref}}(\cdot \mid s_t)\right)$ at each token generation step $t$, and to add the final reward model score $R(x)$ only at the last step, $T$. So, $r_t = r_{\text{KL},t}$ for $t < T$, and $r_T = r_{\text{KL},T} + R(x)$.
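To make this concrete, here is a minimal sketch of how such a per-token reward vector could be assembled for one sequence. It assumes per-token log-probabilities from the current and reference policies are already available as arrays, and it uses the common per-token log-ratio approximation of the KL penalty rather than the full distributional KL; the function and argument names are illustrative, not a library API.

```python
import numpy as np

def per_token_rewards(logprobs, ref_logprobs, reward_score, beta=0.1):
    """Assemble the per-token rewards r_t for one generated sequence.

    logprobs, ref_logprobs: shape (T,) arrays holding log pi_theta(a_t | s_t)
        and log pi_ref(a_t | s_t) for the sampled tokens (illustrative inputs).
    reward_score: scalar reward model score R(x) for the complete sequence.
    beta: KL penalty coefficient.
    """
    # Per-token KL penalty, approximated by the log-probability ratio of the
    # sampled token: -beta * (log pi_theta - log pi_ref).
    rewards = -beta * (np.asarray(logprobs) - np.asarray(ref_logprobs))
    # The reward model score is added only at the final generated token.
    rewards[-1] += reward_score
    return rewards
```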
The discount factor $\gamma \in [0, 1]$ determines the present value of future rewards. A $\gamma$ closer to 1 gives more weight to future rewards, while a $\gamma$ closer to 0 prioritizes immediate rewards. For text generation tasks, $\gamma$ is often set close to 1 (e.g., 0.99 or 1.0) because the quality assessment (via the reward model) often depends on the complete sequence.
The advantage function $A(s_t, a_t)$ measures the relative value of taking action $a_t$ in state $s_t$ compared to the expected value of the state $V(s_t)$ under the current policy $\pi_\theta$. It's formally defined as:
$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

where $Q(s_t, a_t)$ is the action-value function, representing the expected return after taking action $a_t$ in state $s_t$ and following the policy $\pi_\theta$ thereafter. Since we typically learn $V(s_t)$ directly using a value network (the critic), we can estimate $Q(s_t, a_t)$ using the immediate reward $r_{t+1}$ and the value of the next state $V(s_{t+1})$:
$$Q(s_t, a_t) \approx r_{t+1} + \gamma V(s_{t+1})$$

Substituting this into the advantage definition gives the one-step Temporal Difference (TD) error, $\delta_t$, often used as a basic advantage estimator:
$$\hat{A}_t \approx \delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$$

This estimate tells us whether the observed outcome ($r_{t+1} + \gamma V(s_{t+1})$) was better or worse than what was expected ($V(s_t)$).
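As a small illustration, the one-step TD errors for a whole sequence can be computed in a vectorized way. This sketch assumes `rewards[t]` holds the reward observed after the action at step $t$ (i.e., $r_{t+1}$ in the formula) and that $V(s_T) = 0$ at the terminal state; the function name is illustrative.

```python
import numpy as np

def td_errors(rewards, values, gamma=1.0):
    """One-step TD errors: delta_t = r_{t+1} + gamma * V(s_{t+1}) - V(s_t).

    rewards: shape (T,), rewards[t] is the reward observed after step t.
    values:  shape (T,), critic estimates V(s_t) for each step.
    """
    values = np.asarray(values, dtype=float)
    next_values = np.append(values[1:], 0.0)  # V(s_{t+1}), with V(s_T) = 0
    return np.asarray(rewards, dtype=float) + gamma * next_values - values
```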
While the TD error $\delta_t$ would be an unbiased estimate of the advantage if $V$ were exact, in practice the learned value function is only an approximation, so an estimate built from a single step inherits its bias. At the other extreme, summing the actual rewards over the rest of the sequence (a Monte Carlo estimate) does not depend on the accuracy of $V$, but it has high variance because it depends on every subsequently sampled token. Either problem, bias or variance, can make the PPO updates unstable, especially in complex tasks like language generation.

Generalized Advantage Estimation (GAE) manages this trade-off by blending TD errors from multiple time steps, interpolating between the single-step TD estimate and the longer-term Monte Carlo return. GAE introduces a parameter $\lambda \in [0, 1]$ (often called the GAE lambda) to control the balance: $\lambda = 0$ recovers the one-step TD error (low variance, more bias), while $\lambda = 1$ recovers the Monte Carlo-style estimate (low bias, more variance).
The GAE advantage estimator is calculated as an exponentially weighted sum of TD errors:
$$\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{k=0}^{T-t-1} (\gamma \lambda)^k \, \delta_{t+k}$$

where $\delta_{t+k} = r_{t+k+1} + \gamma V(s_{t+k+1}) - V(s_{t+k})$ is the TD error at time step $t+k$. (Note: $V(s_T)$ is typically defined as 0 if $s_T$ is a terminal state.)
The diagram below illustrates how TD errors over multiple steps contribute to the GAE calculation for $\hat{A}_t$.

Calculation flow for Generalized Advantage Estimation (GAE). Rewards ($r$) and value function estimates ($V$) for subsequent states ($s$) are used to compute Temporal Difference (TD) errors ($\delta$). These TD errors are then combined using weights based on $\gamma$ and $\lambda$ to form the final GAE advantage estimate $\hat{A}_t^{\text{GAE}}$.
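In practice, the weighted sum above is rarely computed term by term; it is equivalent to the backward recursion $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$ over the sequence. The sketch below shows this recursion for a single sequence; the function name and default values are illustrative, and real libraries apply the same idea to batched, padded tensors.

```python
import numpy as np

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """GAE via the backward recursion A_t = delta_t + gamma * lam * A_{t+1}.

    rewards: shape (T,), per-token rewards (KL penalty plus final RM score).
    values:  shape (T,), critic estimates V(s_t); V(s_T) is taken to be 0.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    next_value = 0.0      # V(s_T) = 0 at the terminal state
    running_gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * next_value - values[t]
        running_gae = delta + gamma * lam * running_gae
        advantages[t] = running_gae
        next_value = values[t]
    return advantages
```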
In a typical RLHF implementation using libraries like TRL (Transformer Reinforcement Learning), GAE is computed efficiently. During the PPO rollout phase, the policy generates sequences token by token. For each token generated up to the end of the sequence (or a maximum length), the following are typically stored:

- the sampled token (the action $a_t$),
- the log-probability of that token under the current policy $\pi_\theta$,
- the log-probability of that token under the reference policy $\pi_{\text{ref}}$ (needed for the KL penalty), and
- the critic's value estimate $V(s_t)$.
Once a batch of complete sequences is generated:

- the reward model scores each complete sequence, producing $R(x)$;
- per-token rewards are assembled as described above, with the KL penalty at every step and $R(x)$ added at the final step;
- TD errors $\delta_t$ are computed from these rewards and the stored value estimates; and
- GAE advantages $\hat{A}_t$ are accumulated backwards through each sequence, with returns (the regression targets for the value network) recovered as $\hat{A}_t + V(s_t)$, as sketched below.
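Tying these steps together for a single sequence, and reusing the hypothetical helper functions sketched earlier, the flow might look like the following (the numbers are made up purely for illustration; a real implementation would operate on batches of padded tensors):

```python
import numpy as np

# Illustrative values for one generated sequence of four tokens.
logprobs     = np.array([-1.2, -0.8, -2.1, -0.5])   # log pi_theta(a_t | s_t)
ref_logprobs = np.array([-1.0, -0.9, -1.8, -0.6])   # log pi_ref(a_t | s_t)
values       = np.array([0.40, 0.55, 0.60, 0.70])   # critic estimates V(s_t)
reward_score = 1.3                                   # reward model score R(x)

rewards = per_token_rewards(logprobs, ref_logprobs, reward_score, beta=0.1)
advantages = gae_advantages(rewards, values, gamma=1.0, lam=0.95)
returns = advantages + values   # a common choice of target for the value network
```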
It is a standard practice to normalize the advantages across a batch before using them in the PPO loss calculation. This involves subtracting the mean and dividing by the standard deviation of the advantages within the batch, which helps stabilize training by preventing excessively large policy updates.
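A minimal sketch of this whitening step is shown below; the small epsilon is an illustrative safeguard against division by zero when the advantages in a batch are nearly constant.

```python
import numpy as np

def normalize_advantages(advantages, eps=1e-8):
    """Whiten advantages across the batch: zero mean, unit standard deviation."""
    advantages = np.asarray(advantages, dtype=float)
    return (advantages - advantages.mean()) / (advantages.std() + eps)
```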
By carefully calculating returns and using GAE to estimate advantages, we provide the PPO algorithm with a stable and informative signal to guide the LLM policy towards generating responses that align better with the preferences captured by the reward model, while mitigating the instability associated with high-variance gradient estimates.