Having collected a batch of experience by rolling out the current policy against prompts, receiving rewards from the reward model (adjusted by the KL penalty), and estimating state values, we arrive at the core learning phase: the PPO update step. This is where the policy and value networks are adjusted based on the collected data to maximize the expected reward while adhering to the PPO constraints.
This section focuses on the practical implementation details of calculating the PPO loss components and performing the gradient updates. We assume you have already computed the advantages ($A_t$) and returns ($R_t$) for each timestep $t$ in your experience batch, typically using Generalized Advantage Estimation (GAE) as discussed previously.
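For reference, a minimal sketch of how those quantities are typically derived from per-timestep rewards and value estimates is shown below; the function and argument names (compute_gae, rewards, values, gamma, lam) are illustrative rather than taken from any particular library.
# Pseudocode: GAE advantages and returns for a single trajectory
import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # rewards: tensor of shape (T,), per-timestep rewards (already KL-adjusted)
    # values:  tensor of shape (T + 1,), V(s_0) ... V(s_T) including a bootstrap value
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Recursive GAE: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    # Value targets: R_t = A_t + V(s_t)
    returns = advantages + values[:-1]
    return advantages, returns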
Recall that PPO aims to optimize a clipped surrogate objective function. The overall loss function typically combines terms for the policy (actor) and the value function (critic), and often includes an entropy bonus to encourage exploration.
Policy Loss ($L^{CLIP}$): This is the heart of PPO. It uses a ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$, which measures how much the policy has changed since the data was collected. To prevent large, destabilizing updates, this ratio is clipped. The objective term for a single timestep is:
$$L_t^{CLIP}(\theta) = \min\left(r_t(\theta) A_t,\ \text{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t\right)$$
Here, $\epsilon$ is the clipping hyperparameter (e.g., 0.2), $A_t$ is the advantage estimate at timestep $t$, and $\pi_{\theta_{\text{old}}}$ represents the policy that generated the data. Since optimizers typically minimize loss, we take the negative of the expectation (average) of this term over the batch.
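As a quick numerical check with $\epsilon = 0.2$: if $A_t = 1$ and the ratio has grown to $r_t(\theta) = 1.5$, the objective becomes $\min(1.5 \times 1,\ \text{clip}(1.5, 0.8, 1.2) \times 1) = 1.2$, so pushing the action's probability beyond the $1+\epsilon$ boundary earns no additional credit. Symmetrically, with a negative advantage the clipped term removes any incentive to push the ratio below $1-\epsilon$.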
Value Loss ($L^{VF}$): The value network $V_\phi(s_t)$ is trained to predict the expected return (sum of future rewards) from state $s_t$. It is typically trained with a mean squared error (MSE) loss against the computed target returns $R_t$:
$$L_t^{VF}(\phi) = \left(V_\phi(s_t) - R_t\right)^2$$
Some PPO implementations also clip the value function updates to further improve stability, similar to the policy loss clipping, although this is less common than policy clipping.
Entropy Bonus ($S$): To encourage exploration and prevent the policy from collapsing to deterministic outputs too quickly, an entropy bonus is often added to the objective (or, equivalently, subtracted from the loss). For a categorical policy (like language models predicting tokens), the entropy $H$ is calculated as:
$$H(\pi_\theta(\cdot \mid s_t)) = -\sum_{a'} \pi_\theta(a' \mid s_t) \log \pi_\theta(a' \mid s_t)$$
This term is usually averaged over the batch and weighted by a coefficient $c_2$.
The final loss minimized by the optimizer combines these components:
$$L^{PPO}(\theta, \phi) = \mathbb{E}_t\left[-L_t^{CLIP}(\theta) + c_1 L_t^{VF}(\phi) - c_2 H(\pi_\theta(\cdot \mid s_t))\right]$$
where $c_1$ and $c_2$ are hyperparameters balancing the value loss and entropy bonus contributions. Note that the KL divergence penalty, discussed earlier, is typically incorporated into the reward signal used to calculate advantages ($A_t$) and returns ($R_t$), rather than appearing as a separate term in this PPO loss function. This simplifies the update step while still ensuring the policy does not deviate too drastically from the reference (SFT) model.
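One common arrangement, sketched below under that assumption (names such as kl_coef and reward_model_score are illustrative), is to build the per-token reward sequence fed into GAE like this:
# Pseudocode: folding the KL penalty into the per-token rewards
# policy_log_probs, ref_log_probs: per-token log probabilities for one generated response
# reward_model_score: scalar score from the reward model for the full response
kl_per_token = policy_log_probs - ref_log_probs      # approximate per-token KL term
rewards = -kl_coef * kl_per_token                    # KL penalty applied at every token
rewards[-1] = rewards[-1] + reward_model_score       # RM score attributed to the final token
# These rewards are then used to compute the advantages A_t and returns R_t via GAE.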
In practice, the update step involves iterating over the collected batch of experience multiple times (epochs) using mini-batches. Here’s a breakdown of the process within a single PPO update epoch for one mini-batch:
1. Load Mini-batch: Get a subset of the collected experience data (prompts, generated sequences, log probabilities under $\pi_{\theta_{\text{old}}}$, rewards, returns $R_t$, advantages $A_t$, attention masks, etc.).
2. Forward Pass: Pass the sequences from the mini-batch through the current policy ($\pi_\theta$) and value ($V_\phi$) networks.
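For causal language models, this usually means gathering the log probability of each generated token from the policy's logits and reading one value estimate per token from the value network (or a value head). The sketch below assumes Hugging Face style models where policy_model(...) returns an object with a .logits attribute and value_model(...) returns per-token values; those names are illustrative.
# Pseudocode
# input_ids, attention_mask come from the mini-batch
logits = policy_model(input_ids, attention_mask=attention_mask).logits
log_probs_all = torch.log_softmax(logits[:, :-1, :], dim=-1)
# Log probability of each actually generated token (targets are shifted by one position)
current_log_probs = torch.gather(
    log_probs_all, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)
# Per-token value estimates from the value network (or value head)
current_values = value_model(input_ids, attention_mask=attention_mask)
# In practice, positions belonging to the prompt (and padding) are masked out
# so that losses are computed only over generated response tokens.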
3. Calculate Ratios: Compute the probability ratio $r_t(\theta)$ using the newly computed log probabilities and the old log probabilities stored in the mini-batch:
$$r_t(\theta) = \exp\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\right)$$
4. Calculate Policy Loss: Compute the clipped surrogate objective using the calculated ratios $r_t(\theta)$, the advantages $A_t$ from the mini-batch, and the clipping parameter $\epsilon$. Average the negative of this term over the mini-batch.
# Pseudocode using PyTorch-like syntax
import torch
# Assuming advantages, old_log_probs are part of the mini-batch
# current_log_probs are computed from the current policy model
ratio = torch.exp(current_log_probs - batch['old_log_probs'])
surr1 = ratio * batch['advantages']
surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * batch['advantages']
# Negative sign because optimizers minimize loss
policy_loss = -torch.min(surr1, surr2).mean()
5. Calculate Value Loss: Compute the MSE between the current value estimates $V_\phi(s_t)$ and the target returns $R_t$ from the mini-batch.
# Pseudocode
# current_values are computed from the current value model
# returns are part of the mini-batch
value_loss = ((current_values - batch['returns'])**2).mean()
Optional: Implement value clipping if desired.
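If you do clip the value function, one widely used variant limits how far the new prediction can move from the value recorded at rollout time; old_values and value_clip_range below are illustrative names for quantities stored during rollout and a hyperparameter, respectively.
# Pseudocode: clipped value loss variant
values_clipped = batch['old_values'] + torch.clamp(
    current_values - batch['old_values'], -value_clip_range, value_clip_range
)
vf_loss_unclipped = (current_values - batch['returns'])**2
vf_loss_clipped = (values_clipped - batch['returns'])**2
# Taking the elementwise maximum keeps the more pessimistic of the two losses
value_loss = torch.max(vf_loss_unclipped, vf_loss_clipped).mean()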
6. Calculate Entropy Bonus (Optional): Compute the entropy of the policy's output distribution. This often involves accessing the full probability distribution over the vocabulary from the policy network's forward pass. Average this entropy over the mini-batch; since the optimizer minimizes the loss, the entropy term is subtracted from the total loss (weighted by $c_2$) in the combination step.
# Pseudocode (simplified)
# logits from policy model forward pass
probs = torch.softmax(logits, dim=-1)
log_probs = torch.log_softmax(logits, dim=-1)
entropy = -(probs * log_probs).sum(dim=-1).mean()
# The entropy term is subtracted from the total loss (scaled by c2) in the next step,
# which is equivalent to adding an entropy bonus to the maximized objective.
7. Combine Losses: Compute the total loss using the coefficients $c_1$ and $c_2$.
# Pseudocode
vf_coef = 0.5 # Example coefficient c1
entropy_coef = 0.01 # Example coefficient c2
total_loss = policy_loss + vf_coef * value_loss - entropy_coef * entropy
# Subtracting the (positive) entropy term lowers the loss for higher-entropy
# policies, which encourages exploration.
8. Backward Pass and Optimization: Perform backpropagation on the total_loss to compute gradients for both the policy and value network parameters. Update the parameters using an optimizer like Adam or AdamW.
# Pseudocode
optimizer.zero_grad()
total_loss.backward()
# Optional: Gradient clipping can be applied here
# torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_grad_norm)
# torch.nn.utils.clip_grad_norm_(value_model.parameters(), max_grad_norm)
optimizer.step()
Repeat steps 1-8 for all mini-batches, and loop over the full experience batch for the specified number of PPO epochs; a minimal outline of this outer loop is sketched below.
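In the sketch, experience_buffer, make_minibatches, compute_losses, and ppo_epochs are illustrative placeholders for the data structures and helpers in your own implementation.
# Pseudocode: outer PPO update loop over one collected batch of experience
for epoch in range(ppo_epochs):                                  # e.g., 1-4 PPO epochs
    for batch in make_minibatches(experience_buffer, minibatch_size):
        # compute_losses wraps steps 2-6: forward pass, ratios, policy/value/entropy terms
        policy_loss, value_loss, entropy = compute_losses(batch)
        total_loss = policy_loss + vf_coef * value_loss - entropy_coef * entropy
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_grad_norm)
        optimizer.step()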
Mixed-precision training with float16 or bfloat16 can significantly speed up training and reduce memory usage, which is especially important for large LLMs.
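A minimal sketch of wiring this into the update step with PyTorch's automatic mixed precision is shown below; compute_total_loss is an illustrative placeholder for steps 2-7, and with bfloat16 the gradient scaler is usually unnecessary.
# Pseudocode: mixed-precision version of the update step (float16 on a CUDA device)
scaler = torch.cuda.amp.GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    total_loss = compute_total_loss(batch)        # steps 2-7 above
optimizer.zero_grad()
scaler.scale(total_loss).backward()
scaler.unscale_(optimizer)                        # so gradient clipping sees true-scale gradients
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()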
This practice step forms the computational core of the PPO fine-tuning phase. By repeatedly sampling experience and updating the policy and value networks using this carefully constructed objective, the language model gradually learns to generate responses that better align with the preferences encoded in the reward model, while the KL penalty and PPO clipping mechanism prevent it from changing too drastically or unstably. Mastering this update step is fundamental to successful RLHF implementation.