While Proximal Policy Optimization (PPO) provides the mechanism for updating our language model based on the reward signal from the trained Reward Model (RM), applying it naively can lead to undesirable outcomes. The RL policy ($\pi_{RL}$) might discover sequences that achieve high scores from the RM but are unnatural, repetitive, nonsensical, or drastically different in style from the original Supervised Fine-Tuned (SFT) model ($\pi_{SFT}$). This phenomenon is sometimes called "reward hacking" or "mode collapse," where the model over-optimizes for the proxy reward signal, losing the general language capabilities learned during pre-training and SFT.
To mitigate this, the standard RLHF pipeline introduces a penalty term into the PPO objective function based on the Kullback-Leibler (KL) divergence between the RL policy's output distribution and the SFT model's output distribution.
The KL divergence, denoted $\mathrm{KL}(P \,\|\, Q)$, measures how one probability distribution $P$ differs from a reference probability distribution $Q$. A KL divergence of zero indicates that the distributions are identical. In the context of RLHF, we want to measure how much the current RL policy $\pi_{RL}(y \mid x)$ diverges from the reference SFT policy $\pi_{SFT}(y \mid x)$ for a given input prompt $x$ and potential output token $y$.
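As a small, self-contained illustration (the numbers here are toy values, not drawn from any model), the divergence between two next-token distributions can be computed directly from its definition, and it vanishes when the two distributions coincide:

import torch

# Two example next-token distributions over a toy 4-token vocabulary.
p = torch.tensor([0.70, 0.20, 0.05, 0.05])  # P: the policy being evaluated
q = torch.tensor([0.50, 0.30, 0.10, 0.10])  # Q: the reference policy

# KL(P || Q) = sum_i P(i) * log(P(i) / Q(i))
kl_pq = (p * (p / q).log()).sum()
print(kl_pq.item())  # positive: the distributions differ

# Identical distributions give a KL divergence of exactly zero.
print((p * (p / p).log()).sum().item())  # 0.0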
The goal is not just to maximize the reward from the RM but to do so without straying too far from the behavior of the well-behaved SFT model. We incorporate this constraint directly into the reward signal used by PPO. The modified reward for generating a sequence $y = (y_1, \dots, y_T)$ given prompt $x$ becomes:
$$\text{Total Reward}(x, y) = R_{RM}(x, y) - \beta \cdot \mathrm{KL}\big(\pi_{RL}(y \mid x) \,\|\, \pi_{SFT}(y \mid x)\big)$$

Here:

- $R_{RM}(x, y)$ is the scalar score the trained Reward Model assigns to the full sequence.
- $\beta$ is a coefficient controlling the strength of the KL penalty.
- $\pi_{RL}(y \mid x)$ is the output distribution of the RL policy being trained.
- $\pi_{SFT}(y \mid x)$ is the output distribution of the frozen SFT reference model.
Calculating the exact KL divergence over entire sequences is often intractable. Instead, PPO typically operates on a per-token level. During the generation (rollout) phase of PPO, for each token $y_t$ generated at step $t$ given the context $x, y_{<t}$, we calculate the log probability of that token under both the current RL policy, $\log \pi_{RL}(y_t \mid x, y_{<t})$, and the frozen SFT reference model, $\log \pi_{SFT}(y_t \mid x, y_{<t})$.
The per-token KL penalty term is often approximated by the difference between these log probabilities:
$$\text{Per-token KL penalty} \approx \log \pi_{RL}(y_t \mid x, y_{<t}) - \log \pi_{SFT}(y_t \mid x, y_{<t})$$

This term is subtracted (after scaling by $\beta$) from the per-token reward provided by the RM before calculating advantages and updating the policy. The reference SFT model's parameters are kept frozen during this RL phase; it only performs forward passes to provide the reference log probabilities.
Let's look at a PyTorch snippet illustrating how this penalty might be calculated for a batch of generated sequences. Assume we have the log probabilities for each token in the vocabulary at each position from both models.
import torch
import torch.nn.functional as F
# Assume log_probs_rl and log_probs_sft are log-softmax outputs
# from models
# Shape: [batch_size, sequence_length, vocab_size]
# actions: Indices of the tokens actually generated by the RL policy
# during rollout
# Shape: [batch_size, sequence_length]
def calculate_kl_penalty(log_probs_rl, log_probs_sft, actions, beta):
    """
    Calculates the KL penalty term used in RLHF PPO.

    Args:
        log_probs_rl: Log probabilities from the RL policy model.
        log_probs_sft: Log probabilities from the reference SFT model
            (detached).
        actions: The token IDs generated by the RL policy.
        beta: The KL penalty coefficient.

    Returns:
        kl_penalty: A tensor of KL penalties for each token position.
            Shape: [batch_size, sequence_length]
    """
    # Ensure log_probs_sft doesn't contribute to gradients
    log_probs_sft_detached = log_probs_sft.detach()

    # Gather the log probabilities of the specific actions taken
    log_prob_rl_taken = torch.gather(
        log_probs_rl, 2, actions.unsqueeze(2)
    ).squeeze(2)
    log_prob_sft_taken = torch.gather(
        log_probs_sft_detached, 2, actions.unsqueeze(2)
    ).squeeze(2)

    # Calculate the approximate per-token KL divergence (log ratio)
    log_ratio = log_prob_rl_taken - log_prob_sft_taken

    # Scale by beta
    kl_penalty = beta * log_ratio
    return kl_penalty
# --- Example Usage ---
# batch_size, seq_len, vocab_size = 4, 50, 32000
# log_probs_rl = torch.randn(
# batch_size, seq_len, vocab_size
# ).log_softmax(dim=-1)
# log_probs_sft = torch.randn(
# batch_size, seq_len, vocab_size
# ).log_softmax(dim=-1)
# actions = torch.randint(0, vocab_size, (batch_size, seq_len))
# beta_kl = 0.1
# penalty = calculate_kl_penalty(
# log_probs_rl, log_probs_sft, actions, beta_kl
# )
# print(f"KL Penalty Shape: {penalty.shape}")
# # Output: KL Penalty Shape: torch.Size([4, 50])
# This 'penalty' would typically be subtracted from the reward
# provided by the RM before calculating advantages in the PPO algorithm.
# reward_signal = reward_from_rm - penalty
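The last commented line above glosses over where each quantity lives. A common arrangement in PPO-based RLHF implementations (one convention among several; the shapes and names below are illustrative) is to apply the KL penalty at every token position while the scalar RM score is added only at the final generated token:

import torch

# Illustrative sketch: combining the per-token KL penalty with a
# sequence-level RM score before advantage estimation.
batch_size, seq_len = 4, 50
penalty = torch.zeros(batch_size, seq_len)   # e.g. from calculate_kl_penalty
reward_from_rm = torch.randn(batch_size)     # one scalar RM score per sequence

per_token_rewards = -penalty                 # KL penalty at every position
per_token_rewards[:, -1] += reward_from_rm   # RM score only at the last token
# per_token_rewards then feeds into advantage estimation (e.g. GAE) in PPO.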
The calculate_kl_penalty function above demonstrates the common practice of using the log probability ratio of the actions taken as the KL penalty term. It's computationally cheaper than computing the full KL divergence over the entire vocabulary at each step but serves the same purpose of penalizing deviations.
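For reference, that exact per-position KL can be computed by reducing over the vocabulary dimension, using the same [batch, seq, vocab] log-probability tensors assumed above; the sketch below shows the extra work this would require at every position:

def calculate_full_kl(log_probs_rl, log_probs_sft):
    """Exact per-position KL(pi_RL || pi_SFT), summed over the vocabulary."""
    probs_rl = log_probs_rl.exp()
    # KL(P || Q) = sum_v P(v) * (log P(v) - log Q(v)) at each position
    kl = (probs_rl * (log_probs_rl - log_probs_sft.detach())).sum(dim=-1)
    return kl  # Shape: [batch_size, sequence_length]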
The KL penalty acts as an important regularization term. It prevents the RL policy from collapsing onto narrow modes that exploit the reward model while sacrificing linguistic quality. By tying the RL policy back to the SFT policy, it helps ensure that the model retains its fluency, coherence, and the broad knowledge embedded within it.
Choosing the right value for $\beta$ is important. If $\beta$ is too small, the constraint is weak and the policy can drift far from $\pi_{SFT}$, reintroducing the reward-hacking behavior described earlier; if it is too large, the policy stays so close to the SFT model that it barely improves on the preferences captured by the RM. Finding an appropriate $\beta$ therefore usually involves empirical tuning and careful evaluation of the trade-off between maximizing the RM score and maintaining desirable linguistic properties.
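Rather than fixing $\beta$ by hand, some implementations adjust it during training so that the measured KL stays near a chosen target, following the adaptive controller described by Ziegler et al. (2019). The sketch below is a minimal version of that idea; the initial value, target KL, and horizon are assumed hyperparameters, not values taken from this text:

class AdaptiveKLController:
    """Nudges beta up when the measured KL exceeds the target, down otherwise."""

    def __init__(self, init_beta=0.1, target_kl=6.0, horizon=10000):
        self.beta = init_beta
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped to keep the multiplicative update stable
        error = max(min(current_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.beta *= 1.0 + error * n_steps / self.horizon


# Example: a measured KL well above target gradually increases beta.
controller = AdaptiveKLController(init_beta=0.1, target_kl=6.0)
controller.update(current_kl=9.0, n_steps=256)
print(controller.beta)  # slightly above 0.1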
The overall PPO objective therefore balances two forces: maximizing the score from the Reward Model and keeping the $\beta$-scaled KL penalty small, which discourages deviation from the original SFT model's behavior.
In summary, the KL divergence penalty is a standard technique in RLHF used to stabilize training and prevent the language model from deviating excessively from its initial SFT state. It encourages the model to find outputs that are preferred according to human feedback (as captured by the RM) while preserving the fluency and general capabilities learned during its prior training stages.