With the Supervised Fine-Tuned (SFT) model providing a strong baseline for instruction following and the Reward Model (RM) trained to score responses according to human preferences, we now combine these components in the Reinforcement Learning (RL) phase. The objective is to refine the SFT model further, optimizing it to generate outputs that receive high scores from the RM and thereby steering it toward behaviors that humans prefer. Proximal Policy Optimization (PPO) is the workhorse algorithm for this phase in most RLHF pipelines. It is relatively stable, and because it reuses each batch of samples for several updates, it is more sample-efficient than vanilla policy-gradient methods.
The RL Framework in LLM Fine-Tuning
Let's map the standard RL terminology to our LLM context:
- Agent/Policy ($\pi_\theta$): This is the language model we are actively fine-tuning. It starts as a copy of the SFT model, and its parameters $\theta$ are updated during the PPO process. The policy $\pi_\theta(a \mid s)$ defines the probability of generating the next token $a$ given the current sequence of tokens $s$.
- Action (a): The action corresponds to selecting the next token from the model's vocabulary to append to the sequence.
- State ($s$): The state is the prompt followed by all tokens generated so far. The initial state is the prompt alone.
- Environment: The environment is implicitly defined. It takes the model's generated sequence (action sequence) in response to a prompt (initial state) and returns a reward. The core components are the prompt distribution and the reward function (our RM + KL penalty).
- Reward Function ($R(s, a)$ or $R(\text{sequence})$): This is the critical signal guiding the optimization. In RLHF, the reward for a complete generated sequence (prompt + response) is determined primarily by the score from the trained Reward Model (RM). This RM score is combined with a penalty term based on the Kullback-Leibler (KL) divergence between the current policy $\pi_\theta$ and the initial SFT policy $\pi_{\text{ref}}$. The penalty maintains linguistic coherence and prevents the policy from deviating too far from the reliable SFT model.
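To make this mapping concrete, here is a toy, self-contained sketch of the token-level environment in plain Python. Everything in it is hypothetical:

- `TokenEnv`, the end-of-sequence id `EOS = 0`, and the `score_fn` stand-in for the RM are illustrative names, not part of any library.
- The KL penalty is omitted. The sketch returns only the sequence-level score, and only when the episode ends.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

EOS = 0  # hypothetical end-of-sequence token id


@dataclass
class TokenEnv:
    """Toy token-level environment (an illustration, not a real library API)."""
    prompt: Tuple[int, ...]
    score_fn: Callable[[Tuple[int, ...]], float]  # hypothetical stand-in for the RM
    max_new_tokens: int = 4
    response: List[int] = field(default_factory=list)

    def state(self) -> Tuple[int, ...]:
        # s_t = prompt + tokens generated before step t
        return self.prompt + tuple(self.response)

    def step(self, action: int) -> Tuple[Tuple[int, ...], float, bool]:
        # a_t is the next token; appending it produces s_{t+1}.
        self.response.append(action)
        done = action == EOS or len(self.response) >= self.max_new_tokens
        # The sequence-level score arrives only when the episode ends.
        reward = self.score_fn(self.state()) if done else 0.0
        return self.state(), reward, done


# Usage: the "policy" is a fixed list of tokens, purely for illustration.
env = TokenEnv(prompt=(5, 6), score_fn=lambda seq: float(len(seq)))
for tok in [7, 8, EOS]:
    s, r, done = env.step(tok)
    if done:
        break
print(s, r, done)
```

Because the score arrives only at the final step, every intermediate step has reward 0. The per-token KL penalty described below is what gives the intermediate tokens a nonzero signal.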
The PPO Iteration Cycle
The PPO fine-tuning process is iterative, typically involving these steps in each iteration:
- Rollout: Sample a batch of prompts (e.g., from the SFT training distribution or a separate prompt dataset). For each prompt, generate a response with the current policy $\pi_\theta$ by sampling tokens autoregressively until an end-of-sequence token is produced or a maximum length is reached. During generation, store the following for each token step $t$:
  - The state $s_t$ (the prompt plus the tokens generated before step $t$).
  - The action $a_t$ (the token generated at step $t$).
  - The log probability of that token under the current policy, $\log \pi_\theta(a_t \mid s_t)$. These values are held fixed during the subsequent update and serve as $\log \pi_{\theta_{\text{old}}}(a_t \mid s_t)$ in the PPO ratio.
  - The log probability of that token under the reference SFT policy, $\log \pi_{\text{ref}}(a_t \mid s_t)$. The reference policy $\pi_{\text{ref}}$ is kept frozen throughout the RL phase and is typically the initial SFT model.
  - If a separate value model is used, its estimate $V(s_t)$ for each step.
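The rollout bookkeeping can be sketched in plain Python with a toy stand-in for the language model. This illustration rests on stated assumptions and is not a real model:

- `toy_logits` is a hypothetical function whose next-token logits depend only on the last token, plus a per-model bias vector.
- The vocabulary has four tokens, and token 0 plays the role of end-of-sequence.

For each step, the sketch records $s_t$, $a_t$, $\log \pi_\theta(a_t \mid s_t)$ and $\log \pi_{\text{ref}}(a_t \mid s_t)$.

```python
import math
import random
from typing import Dict, List, Sequence

VOCAB = 4  # toy vocabulary size; token 0 plays the role of end-of-sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_logits(state: Sequence[int], bias: Sequence[float]) -> List[float]:
    # Hypothetical stand-in for a language model: it favors the token after the
    # last one (mod VOCAB) and adds a per-model bias vector.
    last = state[-1]
    return [bias[v] + (1.0 if v == (last + 1) % VOCAB else 0.0) for v in range(VOCAB)]


def rollout(prompt: Sequence[int], policy_bias: Sequence[float], ref_bias: Sequence[float],
            max_new_tokens: int = 5, seed: int = 0) -> List[Dict]:
    rng = random.Random(seed)
    state = list(prompt)
    steps: List[Dict] = []
    for _ in range(max_new_tokens):
        logp = log_softmax(toy_logits(state, policy_bias))   # current policy
        ref_logp = log_softmax(toy_logits(state, ref_bias))  # frozen reference
        action = rng.choices(range(VOCAB), weights=[math.exp(x) for x in logp])[0]
        steps.append({
            "state": tuple(state),         # s_t: prompt + tokens before step t
            "action": action,              # a_t
            "logp": logp[action],          # log pi_theta(a_t | s_t), reused as log pi_old
            "ref_logp": ref_logp[action],  # log pi_ref(a_t | s_t)
        })
        state.append(action)
        if action == 0:  # end-of-sequence
            break
    return steps


steps = rollout(prompt=[1, 2], policy_bias=[0.0, 0.5, 0.0, 0.2], ref_bias=[0.0] * VOCAB)
for step in steps:
    print(step["action"], round(step["logp"], 3), round(step["ref_logp"], 3))
```

A real pipeline usually does not compute the two log probabilities step by step. Instead, it recomputes them in one batched forward pass after generation, as the PyTorch snippet in this section does.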
- Reward Calculation: Once a complete sequence (prompt + response) has been generated, compute its reward. This typically involves:
  - Scoring the final sequence with the trained Reward Model (RM). Call this score $R_{\text{RM}}$. It reflects how well the response matches the human preferences the RM was trained on.
  - Computing a per-token KL penalty for each step $t$: $R^{\text{KL}}_t = -\beta\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\text{ref}}(a_t \mid s_t)\right)$. The log ratio of the sampled token is a single-sample estimate of the KL divergence at that step. The hyperparameter $\beta$ controls the strength of this penalty. A higher $\beta$ more strongly discourages the policy from diverging from the SFT model.
  - Combining the two into a per-token reward. A common approach adds $R_{\text{RM}}$ only at the final token $T$ and applies the KL penalty at every step. This gives $r_t = R^{\text{KL}}_t$ for $t < T$ and $r_T = R^{\text{KL}}_T + R_{\text{RM}}$.
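Summing the per-token rewards shows why this construction penalizes divergence at the sequence level. Write $x$ for the prompt and $y = (a_1, \dots, a_T)$ for the response. The probability of the response factorizes as $\pi(y \mid x) = \prod_{t=1}^{T} \pi(a_t \mid s_t)$. The per-token log ratios therefore add up to the sequence-level log ratio, whose expectation under the policy is the KL divergence:

$$
\begin{aligned}
\sum_{t=1}^{T} r_t
  &= R_{\text{RM}} - \beta \sum_{t=1}^{T} \log \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\text{ref}}(a_t \mid s_t)}
   = R_{\text{RM}} - \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}, \\
\mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\sum_{t=1}^{T} r_t\Big]
  &= \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\big[R_{\text{RM}}\big]
   - \beta\, D_{\mathrm{KL}}\big(\pi_\theta(\cdot \mid x) \,\|\, \pi_{\text{ref}}(\cdot \mid x)\big)
\end{aligned}
$$

Maximizing the undiscounted sum of per-token rewards therefore maximizes the expected RM score minus $\beta$ times the KL divergence from the reference policy. The coefficient $\beta$ sets the trade-off between the two.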
Let's illustrate calculating the log probabilities and the KL penalty component in PyTorch:
```python
import torch
import torch.nn.functional as F

# Assume 'policy_model' is the current model being trained (pi_theta)
# Assume 'ref_model' is the frozen SFT model (pi_ref)
# Assume 'prompt_tokens' is the prompt tensor, shape (batch_size, prompt_len)
# Assume 'generated_tokens' holds the tokens generated by policy_model, shape (batch_size, gen_len),
# right-padded with pad_token_id (assumed distinct from EOS) when responses differ in length
# Assume 'pad_token_id' and the KL coefficient 'beta' are defined
# Assume 'rm_scores' holds one RM score per full sequence, shape (batch_size,)

# Concatenate prompt and generated tokens for model input
input_ids = torch.cat([prompt_tokens, generated_tokens], dim=-1)
attention_mask = (input_ids != pad_token_id).long()

# Get logits from both models. No gradients are needed here: these rollout-time
# log probabilities are treated as constants (they become log pi_old in the PPO ratio).
with torch.no_grad():
    policy_outputs = policy_model(input_ids=input_ids, attention_mask=attention_mask)
    ref_outputs = ref_model(input_ids=input_ids, attention_mask=attention_mask)
policy_logits = policy_outputs.logits
ref_logits = ref_outputs.logits

# Logits at position i predict token i + 1, so the predictions for the generated
# tokens come from the slice [-gen_len-1:-1].
gen_len = generated_tokens.size(1)
policy_log_probs = F.log_softmax(policy_logits[:, -gen_len-1:-1, :], dim=-1)
ref_log_probs = F.log_softmax(ref_logits[:, -gen_len-1:-1, :], dim=-1)

# Gather the log probabilities of the tokens actually generated.
# generated_tokens.unsqueeze(-1) has shape (batch_size, gen_len, 1), as gather requires.
gathered_policy_log_probs = policy_log_probs.gather(dim=-1, index=generated_tokens.unsqueeze(-1)).squeeze(-1)
gathered_ref_log_probs = ref_log_probs.gather(dim=-1, index=generated_tokens.unsqueeze(-1)).squeeze(-1)

# Per-token KL penalty component (log ratio), shape (batch_size, gen_len)
log_ratio = gathered_policy_log_probs - gathered_ref_log_probs
kl_penalty_per_token = -beta * log_ratio

# rm_scores come from the RM applied to the full sequences; the call depends on
# the RM architecture, e.g. (hypothetical interface):
# rm_scores = reward_model(input_ids).score

# Combine rewards: KL penalty at every real response token, RM score added at the
# last non-padding token of each response.
response_mask = attention_mask[:, -gen_len:].to(kl_penalty_per_token.dtype)
rewards = kl_penalty_per_token * response_mask
last_token_idx = response_mask.sum(dim=1).long() - 1
rewards[torch.arange(rewards.size(0), device=rewards.device), last_token_idx] += rm_scores
```
PyTorch snippet illustrating the calculation of log probabilities from the policy and reference models and the resulting KL penalty term. The KL coefficient beta is an important hyperparameter.
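The slice `[-gen_len-1:-1]` in the snippet is the usual off-by-one trap with causal language models, so a quick check in plain Python may help. The sketch below lists which logit positions score which generated tokens and confirms that the slice selects exactly those positions. The helper name `generated_logit_positions` and the token labels are made up for illustration.

```python
from typing import List


def generated_logit_positions(seq_len: int, gen_len: int) -> List[int]:
    # In a causal LM the logits at position i are the prediction for token i + 1,
    # so the generated tokens at positions seq_len - gen_len ... seq_len - 1 are
    # scored by the logits at positions seq_len - gen_len - 1 ... seq_len - 2.
    return [p - 1 for p in range(seq_len - gen_len, seq_len)]


# Hypothetical sequence: 3 prompt tokens followed by 2 generated tokens.
seq = ["p1", "p2", "p3", "g1", "g2"]
gen_len = 2
positions = list(range(len(seq)))

# The slice [-gen_len-1:-1] used in the snippet selects exactly these positions.
assert positions[-gen_len - 1:-1] == generated_logit_positions(len(seq), gen_len)
for pos in positions[-gen_len - 1:-1]:
    print(f"logits at position {pos} ({seq[pos]}) predict {seq[pos + 1]}")
```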
- Optimization (PPO Update): Use the collected trajectories (states, actions, log probabilities, rewards) to update the policy $\pi_\theta$. PPO does this by optimizing a surrogate objective function. The advantage $A_t = Q(s_t, a_t) - V(s_t)$ measures how much better action $a_t$ is than the average action from state $s_t$. The core idea is to make actions with a positive advantage more likely and actions with a negative advantage less likely. PPO combines importance sampling with a clipped objective to keep these updates stable:
  - Advantage Estimation: Compute advantage estimates $\hat{A}_t$ for each step. Vanilla policy gradient methods use the return $G_t$ (the sum of future rewards), whereas PPO implementations often use Generalized Advantage Estimation (GAE) to reduce variance. In some RLHF implementations, the RM score itself serves as a proxy for $Q(s_t, a_t)$ or for the advantage. It is often normalized, or combined with a learned value estimate $V(s_t)$. A separate value model is commonly used to supply $V(s_t)$ and refine the advantage estimates. This model is often initialized from the RM or the SFT model and trained to predict the expected cumulative reward, including the KL penalty.
  - Policy Update: Compute the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} = \exp\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\right)$, where $\pi_{\theta_{\text{old}}}$ is the policy before the update (the one that generated the rollouts). The clipped surrogate objective is:

    $$L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right)\hat{A}_t\right)\right]$$

    The clip function restricts the ratio $r_t(\theta)$ to the interval $[1-\epsilon, 1+\epsilon]$, preventing excessively large policy updates that could destabilize training. $\epsilon$ is a small hyperparameter (e.g., 0.2).
  - Perform multiple epochs of stochastic gradient ascent on this objective over the collected batch of rollout data, updating the policy parameters $\theta$. If a separate value model is used, it is typically updated at the same time by minimizing the mean squared error between its predictions and the observed returns.
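The clipped surrogate can be checked on a few hand-picked numbers. This plain-Python sketch computes $L^{\text{CLIP}}$ as a mean over tokens from new log probabilities, old log probabilities, and advantage estimates. It uses $\epsilon = 0.2$ as in the text. The function name `ppo_clip_objective` is hypothetical.

```python
import math
from typing import Sequence


def ppo_clip_objective(new_logp: Sequence[float], old_logp: Sequence[float],
                       advantages: Sequence[float], eps: float = 0.2) -> float:
    """Mean clipped surrogate L^CLIP over tokens; PPO maximizes this value."""
    total = 0.0
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(lp_new - lp_old)                # r_t(theta)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)  # clip(r_t, 1 - eps, 1 + eps)
        total += min(ratio * adv, clipped * adv)         # pessimistic (lower) bound
    return total / len(advantages)


# Positive advantage, ratio 1.5: the objective is capped at (1 + eps) * A.
print(ppo_clip_objective([math.log(1.5)], [0.0], [1.0]))
# Negative advantage, ratio 0.5: the clipped term (1 - eps) * A is the minimum.
print(ppo_clip_objective([math.log(0.5)], [0.0], [-1.0]))
```

Consider the two cases in the example:

- With a positive advantage and a ratio of 1.5, the objective is capped at $1.2 \cdot \hat{A}_t$, so pushing the ratio higher brings no extra gain.
- With a negative advantage and a ratio of 0.5, the minimum selects the clipped value $0.8 \cdot \hat{A}_t$. Its gradient with respect to $\theta$ is zero, so the update stops pushing that token's probability further down.

In a PyTorch training loop, the same expression is typically written with `torch.exp`, `torch.clamp`, and `torch.min` on tensors, and the negated mean is minimized.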
This cycle of rollout, reward calculation, and PPO optimization repeats. It gradually shifts the policy $\pi_\theta$ toward responses that receive higher scores from the RM, while the KL penalty keeps it anchored to the fluency and capabilities learned during SFT.
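The whole cycle can be seen end to end on a deliberately tiny problem. In this plain-Python sketch the "response" is a single token from a three-token vocabulary. All numbers are made up for illustration: the reference logits, the per-token RM scores, $\beta$, the learning rate, the batch size, and the iteration counts. Two simplifications stand in for the full method:

- A batch-mean baseline replaces the learned value model and GAE.
- The gradient of the clipped objective is written out analytically instead of using autograd.

```python
import math
import random

K = 3                          # toy vocabulary size; the "response" is one token
RM_SCORE = [0.0, 1.0, 0.2]     # hypothetical reward-model score per token
REF_LOGITS = [0.5, 0.0, 0.0]   # hypothetical frozen SFT/reference policy
BETA, EPS, LR = 0.1, 0.2, 0.5  # illustrative hyperparameters
BATCH, PPO_EPOCHS, ITERATIONS = 64, 4, 30


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def ppo_iteration(theta, rng):
    # 1) Rollout with the current policy, which becomes pi_old for this iteration.
    old_probs = softmax(theta)
    ref_probs = softmax(REF_LOGITS)
    actions = rng.choices(range(K), weights=old_probs, k=BATCH)
    # 2) Reward: RM score plus the KL penalty on the sampled token.
    rewards = [RM_SCORE[a] - BETA * (math.log(old_probs[a]) - math.log(ref_probs[a]))
               for a in actions]
    # Simplification: a batch-mean baseline stands in for a learned value model.
    baseline = sum(rewards) / BATCH
    advantages = [r - baseline for r in rewards]
    # 3) Several epochs of gradient ascent on the clipped surrogate.
    for _ in range(PPO_EPOCHS):
        probs = softmax(theta)
        grad = [0.0] * K
        for a, adv in zip(actions, advantages):
            ratio = probs[a] / old_probs[a]
            if (adv > 0 and ratio > 1 + EPS) or (adv < 0 and ratio < 1 - EPS):
                continue  # the clipped branch is active: zero gradient
            # d/dtheta_k [ratio * adv] = ratio * adv * (1[k == a] - probs[k])
            for k in range(K):
                grad[k] += ratio * adv * ((1.0 if k == a else 0.0) - probs[k]) / BATCH
        theta = [t + LR * g for t, g in zip(theta, grad)]
    return theta


rng = random.Random(0)
theta = list(REF_LOGITS)  # the policy starts as a copy of the reference
for _ in range(ITERATIONS):
    theta = ppo_iteration(theta, rng)
print("reference:", [round(p, 3) for p in softmax(REF_LOGITS)])
print("policy:   ", [round(p, 3) for p in softmax(theta)])
```

The policy starts as a copy of the reference. Over the iterations it shifts probability toward token 1, which the hypothetical RM scores highest. Because of the KL term, the target is not a pure argmax. The KL-regularized optimum is proportional to $\pi_{\text{ref}}(a)\exp(R_{\text{RM}}(a)/\beta)$, so a larger $\beta$ keeps the policy closer to the reference.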
Practical Considerations
- Value Function: As mentioned, training a separate value function alongside the policy is standard practice in PPO to reduce the variance of advantage estimates. This value function predicts the expected discounted future reward from a given state (sequence of tokens). It's typically trained using the squared error between its predictions and the observed returns from the rollouts.
- Hyperparameter Tuning: RLHF, and PPO in particular, involves several sensitive hyperparameters: the KL coefficient β, the PPO clipping parameter ϵ, learning rates for the policy and value models, the number of PPO epochs per rollout, batch sizes, and GAE parameters (λ, γ). Finding the right balance is often empirical and computationally intensive.
- Computational Cost: RLHF is computationally demanding. It requires maintaining multiple models (policy, reference, RM, and potentially a value model), performing rollouts (inference), calculating rewards (inference), and performing PPO updates (training). Frameworks such as DeepSpeed and libraries such as `trl` (Transformer Reinforcement Learning) provide optimized implementations to manage this complexity.
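GAE, with the $\gamma$ and $\lambda$ parameters mentioned above, can be sketched for a single response as follows:

- The TD error is $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$.
- The advantage is $\hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l \delta_{t+l}$.
- The returns $\hat{A}_t + V(s_t)$ serve as regression targets for the value function.

Several choices here are illustrative assumptions rather than settings from a specific library: the function names, the convention that the value after the final token is 0, and the defaults $\gamma = 1.0$ and $\lambda = 0.95$.

```python
from typing import List, Sequence, Tuple


def gae(rewards: Sequence[float], values: Sequence[float],
        gamma: float = 1.0, lam: float = 0.95) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation for one generated response.

    rewards[t]: per-token reward r_t (KL penalty, plus the RM score on the last token)
    values[t]:  value estimate V(s_t); the value after the final token is taken as 0
    Returns (advantages, returns); the returns are regression targets for V.
    """
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error delta_t
        running = delta + gamma * lam * running              # discounted sum of deltas
        advantages[t] = running
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def value_loss(predicted: Sequence[float], returns: Sequence[float]) -> float:
    # Mean squared error between the value predictions and the observed returns.
    return sum((p - r) ** 2 for p, r in zip(predicted, returns)) / len(returns)


# Hypothetical inputs: two small KL penalties, then the RM score 1.5 on the last token.
rewards = [-0.02, -0.01, 0.03 + 1.5]
values = [0.9, 1.1, 1.3]
advantages, returns = gae(rewards, values)
print(advantages, returns, value_loss(values, returns))
```

The two extreme settings of $\lambda$ show the trade-off:

- Setting $\lambda = 1$ recovers the Monte Carlo return minus the value baseline (low bias, high variance).
- Setting $\lambda = 0$ uses only the one-step TD error (lower variance, but more bias from an imperfect value function).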
By carefully implementing this PPO-based fine-tuning loop, we can effectively leverage the signal captured in the Reward Model to align the language model's behavior more closely with desired human characteristics like helpfulness and harmlessness, building upon the foundation laid by Supervised Fine-Tuning.