In the Proximal Policy Optimization (PPO) algorithm, an actor-critic approach is standard. This involves maintaining two distinct but related networks: a policy network (the actor) that decides which actions to take (i.e., which tokens to generate), and a value network (the critic) that estimates the expected return from a given state. Let's examine how these are typically implemented for fine-tuning large language models in RLHF.
The Policy Network (Actor)
The policy network is the language model itself. Its primary function is to generate text sequences, token by token, given an input prompt.
- Initialization: The policy network is initialized with the weights of the model obtained after the Supervised Fine-Tuning (SFT) phase (covered in Chapter 2). This provides a strong starting point, ensuring the model already possesses good generative capabilities and adheres to the desired style or format learned during SFT.
- Architecture: The architecture is identical to the underlying LLM used (e.g., a Transformer decoder). No structural changes are usually required compared to the SFT model.
- Function: During the RL phase, for a given prompt (state s), the policy network πθ(a∣s) outputs a probability distribution over the vocabulary for the next token (action a). Text generation involves sampling from this distribution sequentially (see the sketch after this list).
- Optimization Goal: The parameters θ of the policy network are updated using the PPO objective function. The goal is to increase the probability of generating sequences that receive high scores from the reward model, while the KL divergence term (discussed in the next section) prevents the policy from deviating too drastically from the initial SFT model's behavior.
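To make the actor's role concrete, the following is a minimal sketch of token-by-token sampling from a causal LM policy, assuming a Hugging Face transformers model loaded from the SFT checkpoint (the path is a placeholder). In practice the model's built-in generate method does this more efficiently with KV caching and batching, but the loop shows the underlying idea.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch: sample a continuation token by token from the policy pi_theta.
# "path/to/your/sft_model" is a placeholder for your SFT checkpoint.
tokenizer = AutoTokenizer.from_pretrained("path/to/your/sft_model")
policy = AutoModelForCausalLM.from_pretrained("path/to/your/sft_model")
policy.eval()

input_ids = tokenizer("Translate to French: Hello world", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(20):  # generate at most 20 new tokens
        logits = policy(input_ids).logits              # (1, seq_len, vocab_size)
        next_token_logits = logits[:, -1, :]           # distribution for the next token
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # sample a ~ pi_theta(a|s)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```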
The Value Network (Critic)
The value network estimates the expected cumulative future reward from a given state. In PPO, this estimate serves as a baseline to reduce the variance of the policy gradient updates.
- Role: The value network learns the state-value function, Vϕ(s). Given a state s (prompt plus the sequence generated so far), it predicts the total reward the policy expects to receive from that point onwards. This prediction is essential for calculating advantage estimates, often using methods like Generalized Advantage Estimation (GAE), which quantify how much better an action is compared to the average action from that state. The advantage is roughly A(s,a)≈R(s,a)+γVϕ(s′)−Vϕ(s), where R(s,a) is the immediate reward (obtained from the Reward Model), s′ is the next state, and γ is a discount factor.
- Architecture: The value network typically shares the core architecture and parameters with the policy network, reusing the powerful representations learned by the LLM body (e.g., the Transformer layers). A separate "value head" – usually a simple linear layer – is added on top of the final hidden state representations from the shared LLM body. This head maps each hidden state to a single scalar, the predicted state value Vϕ(s) for that position.
- Initialization: While the shared body inherits weights from the SFT model (like the policy network), the value head is usually initialized randomly or with small weights, as it needs to learn a new function (value estimation) from scratch during RL training.
- Optimization Goal: The value network's parameters ϕ (primarily the value head, but also potentially the shared body parameters) are trained concurrently with the policy network. The objective is typically to minimize the Mean Squared Error (MSE) between the predicted values Vϕ(s) and the actual observed returns (targets) calculated during the PPO rollout phase (e.g., using GAE):
L(ϕ) = E(s,Vtarget)[ (Vϕ(s) − Vtarget)² ]
where Vtarget represents the calculated return target for state s.
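To illustrate how Vtarget and the advantages are obtained in practice, below is a minimal sketch of GAE over a single generated sequence. The per-token rewards, tensor shapes, and hyperparameter values are illustrative assumptions, not the API of any particular library.

```python
import torch

def gae_advantages_and_targets(rewards, values, gamma=0.99, lam=0.95):
    """Illustrative GAE computation over one generated sequence.

    rewards: (T,) per-token rewards, typically zero except at the final token,
             which carries the reward-model score.
    values:  (T,) value-head predictions V_phi(s_t) at each token position.
    Returns advantages A_t and value targets V_target_t, both of shape (T,).
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0     # no bootstrap past the end
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        last_gae = delta + gamma * lam * last_gae            # GAE recursion
        advantages[t] = last_gae
    value_targets = advantages + values   # V_target = A + V_phi(s)
    return advantages, value_targets

# Toy rollout: reward arrives only at the final token.
rewards = torch.tensor([0.0, 0.0, 0.0, 1.2])
values = torch.tensor([0.1, 0.2, 0.3, 0.4])
advantages, value_targets = gae_advantages_and_targets(rewards, values)

# Value loss, matching L(phi) above: MSE between predictions and targets.
value_loss = torch.mean((values - value_targets) ** 2)
```

The value targets here are simply the advantages added back onto the current value predictions, which is how GAE-based PPO implementations typically construct Vtarget.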
Parameter Sharing Strategy
Sharing parameters between the policy and value networks is a common and effective strategy in actor-critic methods, especially for large models like LLMs.
- Efficiency: It significantly reduces the number of parameters to train and store, making the process more computationally feasible.
- Representation Learning: The rich representations learned by the LLM body are beneficial for both generating text (policy) and estimating future rewards (value). Sharing allows the value network to leverage these features learned during pre-training and SFT.
The typical implementation involves using the SFT model as the base. During each forward pass in the RL training loop, the input prompt is processed by the shared LLM body, and the final hidden states are then fed into two separate heads (sketched in code after the figure below):
- The Policy Head: Usually the original language modeling head, outputting logits over the vocabulary.
- The Value Head: A new linear layer outputting a single scalar value.
Diagram illustrating the common parameter-sharing architecture. The input state is processed by the shared LLM body, whose outputs feed into separate heads for policy prediction and value estimation.
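To make this wiring concrete, here is a minimal PyTorch sketch of the shared-body, two-head design. The class name PolicyWithValueHead and its details are illustrative, not a library API; the next subsection shows the ready-made equivalent in TRL.

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

class PolicyWithValueHead(nn.Module):
    """Illustrative actor-critic wrapper: shared LLM body, two heads."""

    def __init__(self, sft_model_path):
        super().__init__()
        # Shared body plus the original LM (policy) head, initialized from the SFT model.
        self.lm = AutoModelForCausalLM.from_pretrained(sft_model_path)
        # New value head: a single linear layer mapping each hidden state to a scalar.
        self.value_head = nn.Linear(self.lm.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.lm(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-1]            # (batch, seq_len, hidden_size)
        policy_logits = outputs.logits                # (batch, seq_len, vocab_size)
        values = self.value_head(hidden).squeeze(-1)  # (batch, seq_len)
        return policy_logits, values
```

Note that only the value head is new; everything else is loaded directly from the SFT checkpoint, which is why the policy starts out behaving identically to the SFT model.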
Implementation using Libraries
Libraries like Hugging Face's TRL (Transformer Reinforcement Learning) streamline this setup. TRL provides a wrapper class, AutoModelForCausalLMWithValueHead, which takes a standard pre-trained causal language model (e.g., GPT-2 or Llama) and automatically attaches a value head alongside the existing language modeling head.
```python
# Example using Hugging Face TRL (conceptual sketch; exact signatures may vary by TRL version)
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# Load the SFT model as the base
model_name = "path/to/your/sft_model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

# 'model' now wraps both the policy (LM) head and a randomly initialized value head.
# A forward pass yields both outputs:
prompt = "Translate to French: Hello world"
inputs = tokenizer(prompt, return_tensors="pt")

# TRL's value-head models return the LM logits, an optional LM loss,
# and per-token value estimates from the value head.
policy_logits, _, value_estimates = model(**inputs)
# policy_logits:   (batch, seq_len, vocab_size)
# value_estimates: (batch, seq_len) -- one value per token position

# TRL's PPO trainer consumes these outputs when generating responses,
# calculating advantages, and updating both the policy and the value head.
```
By initializing the policy network from the SFT model and using a shared architecture with a dedicated value head, we establish the necessary components for the PPO algorithm. The value network's role in providing accurate state-value estimates is fundamental for stable and effective policy updates during the RL fine-tuning phase. The next sections will detail how these components interact within the PPO update loop, particularly focusing on the KL divergence penalty and advantage calculation.