The RLHF pipeline relies on the interplay of several distinct models, each serving a specific purpose. Correctly loading these models from their respective training stages (like SFT and Reward Modeling) and initializing them for the PPO phase is a fundamental step before RL fine-tuning can begin. This section details the typical loading and initialization strategies.
At the start of the PPO phase (Stage 3), you'll typically need access to the following models:
- The Policy Model (actor), which generates responses and is updated during PPO.
- The Reference Model, a frozen copy of the starting policy used to compute the KL divergence penalty.
- The Reward Model, which scores generated responses.
- The Value Model (critic), which estimates expected returns for advantage calculation.
Managing these potentially large models requires careful handling of checkpoints and computational resources. Libraries like Hugging Face Transformers provide convenient methods for loading models trained in previous stages.
1. Initializing the Policy and Reference Models:
Both the Policy Model (the one being trained) and the Reference Model (the frozen one for KL divergence) typically start from the same set of weights: the checkpoint saved after the Supervised Fine-Tuning (SFT) phase (Stage 1). Loading the SFT model provides a strong starting point for the policy, as it's already adapted to the desired style or domain from the SFT dataset.
You would load the SFT model checkpoint twice. One instance becomes the policy_model, whose weights will be updated during PPO training. The second instance becomes the ref_model, which must be set to evaluation mode (.eval()) and have its parameters frozen to prevent any updates. Many RLHF libraries, like TRL (Transformer Reinforcement Learning), handle the reference model implicitly, using the initial state of the policy model as the reference; a minimal explicit version is sketched below.
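As a concrete illustration, here is a minimal sketch of the explicit approach for a custom PPO loop, assuming sft_model_path points to the Stage 1 checkpoint:

from transformers import AutoModelForCausalLM

sft_model_path = "./path/to/your/sft_model_checkpoint"

# Trainable policy: weights are updated by PPO
policy_model = AutoModelForCausalLM.from_pretrained(sft_model_path)

# Frozen reference: same starting weights, used only for the KL penalty
ref_model = AutoModelForCausalLM.from_pretrained(sft_model_path)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False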
2. Initializing the Value Model (Critic):
The Value Model's initialization depends on the architecture choice:
- Shared backbone with a value head: Many libraries, such as TRL, provide classes like AutoModelForCausalLMWithValueHead, which combine the language model (actor) with a value head (critic) that outputs a scalar value. When loading from an SFT checkpoint with such a class, the base LM weights come from the SFT model, while the value head is typically initialized randomly (or from pre-trained weights if available, though it usually still requires fine-tuning). The value head is then trained alongside the policy during PPO.
- Separate value network: Alternatively, the critic can be a standalone model, commonly initialized from the SFT or reward model checkpoint and trained jointly with the policy.

The Value Model's role is critical for calculating advantage estimates (e.g., using Generalized Advantage Estimation, GAE) within the PPO update step.
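To make the critic's role concrete, here is a simplified sketch of GAE for a single generated response, assuming per-token rewards and value predictions of equal length and a terminal value of zero after the last token:

import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # rewards, values: 1-D float tensors over the generated tokens
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        last_gae = delta + gamma * lam * last_gae            # discounted sum of residuals
        advantages[t] = last_gae
    returns = advantages + values  # targets for training the value head
    return advantages, returns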
3. Loading the Reward Model:
The Reward Model (RM) is loaded from the checkpoint saved after Stage 2 (RM training). It's typically a sequence classification model (e.g., AutoModelForSequenceClassification in Transformers) with a single scalar output representing the reward. Since it's only used for scoring generated responses during PPO and is not trained further in this stage, it should be loaded and immediately put into evaluation mode (.eval()) with gradient calculations disabled for inference.
Using libraries like Hugging Face Transformers and TRL simplifies this process significantly. Here’s an illustrative example:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
# Assume these paths point to your trained model checkpoints
sft_model_path = "./path/to/your/sft_model_checkpoint"
rm_model_path = "./path/to/your/reward_model_checkpoint"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# --- Configuration ---
# PPO configuration includes hyperparameters for training
ppo_config = PPOConfig(
    batch_size=32,
    learning_rate=1.41e-5,
    # other PPO params like kl_penalty, epochs, etc.
)
# --- Load Tokenizer ---
# Usually shared across models, loaded from the SFT model path
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
# Ensure padding token is set if needed (common for batch processing)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# --- Load Models ---
# 1. Policy Model (Actor + Value Head)
# This model has both the LM head for generation and a scalar head for value estimation.
# Initialized from SFT checkpoint. It will be trained by PPO.
print(f"Loading Policy/Value model from: {sft_model_path}")
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    sft_model_path,
    torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency on compatible hardware
    # low_cpu_mem_usage=True,  # Useful for large models
    device_map={"": device}  # Load directly to the target device
)
print("Policy/Value model loaded.")
# 2. Reward Model (for scoring generations)
# Initialized from RM checkpoint. Used only for inference.
print(f"Loading Reward Model from: {rm_model_path}")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    rm_model_path,
    num_labels=1,  # Output a single scalar reward score
    torch_dtype=torch.bfloat16,
    # low_cpu_mem_usage=True,
    device_map={"": device}  # Load to the target device
)
reward_model.eval() # Set to evaluation mode
print("Reward model loaded and set to eval mode.")
# Note on Reference Model:
# TRL's PPOTrainer typically handles the reference model automatically.
# It creates an internal, frozen copy of the `policy_model` at initialization.
# If building a custom PPO loop, you would load the SFT model explicitly again
# and ensure its weights remain frozen throughout training.
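# Sketch of an explicit TRL alternative (assumes a TRL version that exports
# `create_reference_model`):
# from trl import create_reference_model
# ref_model = create_reference_model(policy_model)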
# --- Initialize PPOTrainer ---
# The PPOTrainer orchestrates the PPO training loop, managing the policy updates,
# KL divergence calculation against the reference model, and value function training.
print("Initializing PPOTrainer...")
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,  # The model to be trained (includes value head)
    ref_model=None,      # TRL creates a frozen reference copy internally when None
    tokenizer=tokenizer,
    # dataset, data_collator, optimizer, etc. would also be provided here
)
print("PPOTrainer initialized.")
# --- Ready for PPO Training Loop ---
# The system is now set up with the necessary models loaded and initialized.
# The next step involves the PPO training loop: generating responses with
# `policy_model`, scoring them with `reward_model`, and updating `policy_model`
# using the PPO algorithm managed by `ppo_trainer`.
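To show how these pieces fit together, here is a condensed sketch of a single PPO iteration continuing the example above, using TRL's classic generate/step interface (exact signatures vary between TRL versions; query_tensors is assumed to be a list of tokenized prompts already on the target device):

generation_kwargs = {"max_new_tokens": 64, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

# Generate responses with the current policy
response_tensors = []
for query in query_tensors:
    output = ppo_trainer.generate(query, **generation_kwargs).squeeze()
    response_tensors.append(output[query.shape[0]:])  # keep only the newly generated tokens

# Score prompt + response pairs with the frozen reward model
texts = [tokenizer.decode(torch.cat([q, r])) for q, r in zip(query_tensors, response_tensors)]
with torch.no_grad():
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    rewards = [score for score in reward_model(**inputs).logits.squeeze(-1)]

# PPO update: clipped policy-gradient step with a KL penalty against the reference model
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)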
Loading multiple large models is memory-intensive. Consider these points:
- Precision: Use torch.bfloat16 or torch.float16 where possible to reduce memory footprint and potentially speed up computation on compatible hardware (like NVIDIA Ampere GPUs and newer).
- Device placement: Explicitly manage placement with .to(device) or the device_map argument during loading. If using multiple GPUs, you might place the RM on a different GPU than the Policy/Value/Reference models to balance load, as sketched below.
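For example, a sketch of two-GPU placement reusing the paths and classes from the example above (assumes two GPUs are available):

# Policy/value model on the first GPU, reward model on the second
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    sft_model_path, torch_dtype=torch.bfloat16, device_map={"": "cuda:0"}
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    rm_model_path, num_labels=1, torch_dtype=torch.bfloat16, device_map={"": "cuda:1"}
)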
The following diagram illustrates how models from different stages are loaded and utilized within the PPO phase.

Model provenance and initialization flow within the RLHF pipeline. Arrows indicate data or weight transfer between stages and components used during PPO training.
By carefully managing the loading and initialization of these distinct models, ensuring they reside on the appropriate devices and are configured correctly (trainable vs. frozen), you establish the foundation for the PPO fine-tuning process that drives alignment with human preferences.