Fine-tuning large language models with PPO introduces a set of hyperparameters that are critical for balancing optimization effectiveness against training stability. Unlike standard RL tasks, RLHF places the policy (the LLM itself) in a complex, effectively non-stationary setting in which it generates high-dimensional sequences of text. Small changes in hyperparameters can lead to significant differences in model behavior, training time, and alignment quality, so careful tuning is essential.
This section provides guidance on setting and adjusting the most significant PPO hyperparameters within the RLHF context, focusing on considerations specific to LLMs.
Learning Rates
The learning rate determines the step size taken during gradient descent for both the policy (actor) and value (critic) networks. Given the scale of LLMs and the potential for instability when updating billions of parameters, relatively small learning rates are generally preferred.
- Policy Learning Rate: Controls how quickly the LLM's generation strategy adapts based on the reward signal and the PPO objective. A rate that's too high can cause the policy to diverge drastically from the initial SFT model, leading to high KL divergence, nonsensical outputs, and training collapse. A rate that's too low results in slow learning. Typical values often range from 1×10⁻⁶ to 5×10⁻⁵.
- Value Learning Rate: Controls how quickly the value network learns to predict the expected return (reward). It can often be slightly higher than the policy learning rate, as value prediction is a supervised regression problem. However, instability in the value function can negatively impact advantage estimation and policy updates. Typical values might range from 1×10⁻⁵ to 1×10⁻⁴.
It's common practice to use optimizers like Adam or AdamW with learning rate schedulers (e.g., linear decay, cosine decay) to adjust the learning rate over the course of training. Start with conservative (low) learning rates and gradually increase if training appears stable and progress is slow.
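To make this concrete, the sketch below shows one way to set separate learning rates for the policy and value networks with AdamW and a linear decay schedule. The module names and specific values are illustrative placeholders, not part of any particular RLHF library.

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR

# Placeholder modules standing in for the actor (policy) and critic (value) networks.
policy_model = torch.nn.Linear(16, 16)
value_model = torch.nn.Linear(16, 1)

# Separate parameter groups let the value network use a slightly higher
# learning rate than the policy, as discussed above.
optimizer = AdamW([
    {"params": policy_model.parameters(), "lr": 1e-6},  # conservative policy LR
    {"params": value_model.parameters(), "lr": 1e-5},   # value LR can be higher
])

# Linear decay of both learning rates over the planned number of PPO iterations.
scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000)

# Inside the training loop, call optimizer.step() for each gradient update
# and scheduler.step() once per PPO iteration.
```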
Batch Sizes
In PPO for RLHF, two batch sizes are relevant:
- Rollout Batch Size: The number of prompts processed in parallel to generate responses and collect experience (prompts, responses, rewards, log-probabilities). This is often limited by available GPU memory, as it requires running inference with the policy model. Larger rollout batches provide more diverse data per iteration but increase memory requirements. Typical values depend heavily on hardware, ranging from 64 to 1024 or more.
- PPO Mini-batch Size: The size of the data chunk used for calculating gradients during the PPO update epochs. This is sampled from the larger rollout batch. The mini-batch size affects the variance of the gradient estimate. Smaller mini-batches introduce more noise but can sometimes help escape local optima. Larger mini-batches provide more stable gradients but require more memory per update step. Typical values range from 4 to 64, constrained by GPU memory available for gradient computation.
The total amount of experience used per PPO update phase is `rollout_batch_size` samples. This experience is iterated over `ppo_epochs` times, processing `mini_batch_size` samples at each gradient step.
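As a quick illustration (with assumed values), the number of gradient steps produced by a single rollout follows directly from these three quantities:

```python
rollout_batch_size = 256  # prompts per experience-collection phase
mini_batch_size = 16      # samples per gradient step
ppo_epochs = 4            # passes over the rollout buffer

# Each epoch splits the rollout into rollout_batch_size / mini_batch_size chunks,
# so one collection phase yields this many gradient updates:
gradient_steps = (rollout_batch_size // mini_batch_size) * ppo_epochs
print(gradient_steps)  # 64
```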
PPO Epochs
This hyperparameter defines how many times the PPO algorithm iterates over the collected rollout data (the experience stored in the buffer) to update the policy and value networks.
- More Epochs: Allows the model to learn more from each batch of collected experience, potentially improving sample efficiency. However, too many epochs can lead to overfitting on the current batch of data and can cause the policy to move too far from the policy that generated the data, violating PPO's assumptions and causing instability.
- Fewer Epochs: More stable, as the policy updates are smaller between data collection phases. However, it might require more rollouts (more data collection) to achieve the same level of learning.
Common values range from 2 to 10 epochs. The optimal number often depends on other parameters like the learning rate and mini-batch size.
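Schematically, the update phase looks like the following sketch, where `rollout`, `compute_ppo_loss`, and `optimizer` are placeholders for whatever your training code provides; the point is that the same rollout buffer is reshuffled and reused once per epoch.

```python
import numpy as np

def ppo_update(rollout, compute_ppo_loss, optimizer, ppo_epochs=4, mini_batch_size=16):
    """Run several PPO epochs over one rollout buffer.

    `rollout` is assumed to be a dict of equal-length arrays/tensors (tokens,
    old log-probs, advantages, returns, ...); `compute_ppo_loss` is a
    placeholder returning a scalar loss for a mini-batch.
    """
    n = len(rollout["advantages"])
    for _ in range(ppo_epochs):
        # Reshuffle each epoch so mini-batches differ between passes.
        indices = np.random.permutation(n)
        for start in range(0, n, mini_batch_size):
            mb_idx = indices[start:start + mini_batch_size]
            mini_batch = {k: v[mb_idx] for k, v in rollout.items()}
            loss = compute_ppo_loss(mini_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```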
KL Divergence Coefficient (β)
As discussed previously, the KL divergence term $D_{\mathrm{KL}}(\pi_\theta \,\|\, \pi_{\mathrm{SFT}})$ penalizes the policy $\pi_\theta$ for deviating too far from the reference policy $\pi_{\mathrm{SFT}}$ (usually the initial SFT model). The coefficient β controls the strength of this penalty.
- Low β: Allows the policy to explore more aggressively and optimize the reward signal, but increases the risk of deviating significantly from the SFT model's capabilities, potentially leading to incoherent text or loss of general language understanding.
- High β: Keeps the policy close to the SFT model, ensuring stability and preserving language fluency, but may limit the extent to which the model optimizes for the reward signal, potentially resulting in less aligned behavior.
Choosing β is a balancing act. Some implementations use an adaptive KL controller, which adjusts β dynamically during training to keep the KL divergence close to a predefined target value (e.g., 6 nats). This can provide more stability than a fixed β. If using a fixed coefficient, typical starting values might be between 0.01 and 0.2. Monitor the actual KL divergence during training – if it consistently exceeds a desired threshold (e.g., 10-15 nats) or collapses to zero, adjust β accordingly.
*Figure: KL divergence trends during PPO training under different β settings. Lower fixed β values allow KL to grow larger, while higher values restrict it; an adaptive controller adjusts β to maintain KL near a target value.*
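Such an adaptive controller can be a simple proportional update: raise β when the observed KL exceeds the target and lower it otherwise. The sketch below is illustrative; the clip range and horizon are assumed constants, not values from a specific library.

```python
class AdaptiveKLController:
    """Proportional controller that nudges beta so the observed KL stays near a target."""

    def __init__(self, init_beta=0.1, target_kl=6.0, horizon=10000):
        self.beta = init_beta
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl, n_steps):
        # Relative error between observed and target KL, clipped for stability.
        error = max(-0.2, min(0.2, observed_kl / self.target_kl - 1.0))
        # Raise beta when KL is above target, lower it when below.
        self.beta *= 1.0 + error * n_steps / self.horizon
        return self.beta

# Example: KL has drifted to 9 nats against a 6-nat target, so beta is nudged upward.
controller = AdaptiveKLController()
new_beta = controller.update(observed_kl=9.0, n_steps=256)
```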
PPO Clipping Parameter (ϵ)
PPO limits the size of each policy update using a clipped surrogate objective. The clipping parameter ϵ defines the range $[1-\epsilon,\ 1+\epsilon]$ within which the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}$ may vary without being clipped.
- Smaller ϵ (e.g., 0.1): Results in smaller, more conservative policy updates, promoting stability but potentially slowing down convergence.
- Larger ϵ (e.g., 0.3): Allows for larger policy updates, potentially speeding up convergence but increasing the risk of instability if the advantage estimates are noisy or the learning rate is too high.
Typical values for ϵ in PPO are often between 0.1 and 0.3. A common starting point is ϵ=0.2. Its effect is interconnected with the learning rate and the number of PPO epochs.
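In code, the clipped surrogate reduces to a few lines; the sketch below assumes per-token log-probabilities and advantages are computed elsewhere and returns the negated objective so it can be minimized.

```python
import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Negated PPO clipped surrogate objective (minimization form)."""
    # Probability ratio r_t(theta) = pi_theta / pi_old, computed in log space.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO takes the elementwise minimum of the two terms; negate for a minimizer.
    return -torch.min(unclipped, clipped).mean()
```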
Generalized Advantage Estimation (GAE) Lambda (λ)
GAE is used to estimate the advantage function $A(s_t, a_t)$, balancing bias and variance in the estimate. The parameter λ controls this trade-off.
- λ=1: Corresponds to high-variance Monte Carlo estimates of the return.
- λ=0: Corresponds to lower-variance but potentially higher-bias Temporal Difference (TD) estimates (using only the immediate reward and the next state's value).
- 0<λ<1: Interpolates between these extremes.
In practice, values close to 1, such as λ=0.95, are often used in PPO implementations, including those for RLHF. This choice generally provides a good balance for policy gradient updates.
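For reference, a sketch of the backward GAE recursion over a single generated response is shown below; it assumes a discount factor of γ = 1 (common in RLHF setups) and a terminal value of zero after the last token.

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation for one response (illustrative sketch)."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        # TD error: immediate reward plus discounted next value, minus current value.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of TD errors, controlled by lambda.
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = advantages + values  # regression targets for the value function
    return advantages, returns
```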
Value Function Coefficient (c1)
The overall PPO loss function combines the clipped surrogate objective, the value function loss, and sometimes an entropy bonus. The value function coefficient (c1, often denoted `vf_coef`) scales the mean squared error between the predicted values $V_\phi(s_t)$ and the target values (usually computed via GAE):
$$L^{\mathrm{CLIP+VF+S}}(\theta, \phi) = \hat{\mathbb{E}}_t\left[ L_t^{\mathrm{CLIP}}(\theta) - c_1 L_t^{\mathrm{VF}}(\phi) + c_2 S[\pi_\theta](s_t) \right]$$
where $L_t^{\mathrm{VF}}(\phi) = \left( V_\phi(s_t) - V_t^{\mathrm{targ}} \right)^2$.
A typical value for c1 is 0.5 or 1.0. This ensures that the value function is trained effectively alongside the policy, as accurate value estimates are important for good advantage estimation.
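Putting the pieces together, the combined loss can be assembled as in the sketch below, which reuses the negated surrogate from the clipping section; the entropy coefficient `ent_coef` (c2) is an illustrative value, not a recommendation from the text above.

```python
import torch

def ppo_total_loss(policy_loss, values, value_targets, entropy,
                   vf_coef=0.5, ent_coef=0.01):
    """Combine PPO terms into one scalar to minimize.

    Signs are flipped relative to the maximization objective above: the
    (already negated) clipped surrogate and the value error are added,
    while the entropy bonus is subtracted.
    """
    value_loss = torch.mean((values - value_targets) ** 2)
    return policy_loss + vf_coef * value_loss - ent_coef * entropy
```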
Strategies for Tuning
Tuning PPO for LLMs is often an empirical process:
- Start with Defaults: Begin with hyperparameters reported in successful RLHF studies or the default values provided by libraries like TRL (e.g., its `PPOConfig`).
- Prioritize KL Control: Ensure the KL divergence remains within a reasonable range (e.g., < 15-20 nats). Adjust β or the adaptive KL target first if the policy diverges too quickly.
- Monitor Key Metrics: Track the reward mean/distribution, KL divergence, policy loss, value loss, and model entropy during training. Use evaluation prompts periodically to assess generation quality.
- Tune Learning Rate and Batch Size: If training is stable but slow, consider slightly increasing the learning rate or adjusting batch sizes (if memory permits). If unstable, decrease the learning rate.
- Adjust PPO Epochs and Clipping: If experiencing instability with multiple epochs, try reducing the number of epochs or tightening the clipping range (ϵ).
- Iterate: Make small, incremental changes and observe their effects over several training steps. Hyperparameter interactions are complex, so isolating the effect of a single change can be difficult but is often necessary.
Due to the high computational cost of training LLMs, extensive grid searches are often infeasible. Rely on established ranges, monitor training dynamics closely, and make informed adjustments based on observed behavior and evaluation results.
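As a starting point, the ranges discussed in this section can be collected into a single configuration. The dictionary below is an illustrative set of defaults using descriptive key names, not the schema of TRL's `PPOConfig` or any other library; adapt the values to your model size and hardware.

```python
# Illustrative starting hyperparameters drawn from the ranges discussed above.
ppo_defaults = {
    "policy_learning_rate": 1e-6,
    "value_learning_rate": 1e-5,
    "rollout_batch_size": 256,
    "mini_batch_size": 16,
    "ppo_epochs": 4,
    "init_kl_coef": 0.1,    # beta, if not using an adaptive controller
    "target_kl": 6.0,       # nats, if using an adaptive controller
    "clip_eps": 0.2,
    "gae_lambda": 0.95,
    "vf_coef": 0.5,
}
```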