As discussed earlier in this chapter, Direct Preference Optimization (DPO) offers a compelling alternative to the multi-stage RLHF pipeline. Instead of training a separate reward model, DPO directly optimizes the language model policy using preference data. This section provides a hands-on guide to understanding and implementing the DPO loss function, a core component of this technique.
Recall that DPO aims to increase the relative log probability of preferred responses (yw) compared to rejected responses (yl) for a given prompt (x). It achieves this by implicitly defining a reward function based on the ratio of the policy model's (πθ) probability to the reference model's (πref) probability for a given completion. The objective function derived from this formulation is:
$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) = -\mathbb{E}_{(x, y_w, y_l)\sim\mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

Let's break down the terms:
- πθ is the policy model being trained, and πref is the frozen reference model.
- (x, yw, yl) ∼ D is a sample from the preference dataset: a prompt x, a preferred (chosen) response yw, and a rejected response yl.
- σ is the logistic (sigmoid) function.
- β is a temperature-like hyperparameter controlling how strongly the policy is allowed to deviate from the reference model.
The term inside the log σ(⋅) function can be rewritten as:
$$
\beta \Big( \big(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)\big) - \big(\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big) \Big)
$$

This highlights that we are maximizing the difference between the log-probability ratios of the chosen response and the rejected response, scaled by β.
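As a quick sanity check, take the first sample from the example batch used later in this section: log πθ(yw∣x) = −10.5, log πref(yw∣x) = −10.2, log πθ(yl∣x) = −11.2, log πref(yl∣x) = −11.0, with β = 0.1. The scaled margin and the resulting per-sample loss are:

$$
\beta\big((-10.5 + 10.2) - (-11.2 + 11.0)\big) = 0.1 \times (-0.1) = -0.01, \qquad -\log\sigma(-0.01) \approx 0.698
$$

Since a margin of zero corresponds to a loss of log 2 ≈ 0.693, a value slightly above that indicates the implicit reward of the rejected response is currently a bit higher than that of the chosen one, and the gradient will act to push the margin in the positive direction.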
To implement this loss in a typical deep learning framework like PyTorch or TensorFlow, you need the log probabilities of the chosen (yw) and rejected (yl) sequences under both the policy model (πθ) being trained and the frozen reference model (πref).
Here are the steps to compute the loss for a batch of preference data:
1. Obtain Log Probabilities: Perform forward passes for the prompts (x) and both completions (yw, yl) through both the policy model and the reference model. This yields four log-probability values for each sample in the batch (a sketch of how to compute them appears after this list):
   - policy_chosen_logps: log πθ(yw∣x)
   - policy_rejected_logps: log πθ(yl∣x)
   - ref_chosen_logps: log πref(yw∣x)
   - ref_rejected_logps: log πref(yl∣x)
   Remember that the reference model πref is not updated during training; its parameters remain fixed, and gradients are not computed for it. The policy model πθ is the one whose parameters are being optimized.
2. Calculate Log Ratios: Compute the log probability ratios relative to the reference model for both chosen and rejected responses:
log_ratio_w = policy_chosen_logps - ref_chosen_logps
log_ratio_l = policy_rejected_logps - ref_rejected_logps
3. Calculate the Difference: Find the difference between these log ratios and scale it by β:
diff = beta * (log_ratio_w - log_ratio_l)
4. Apply Logistic Loss: Compute the negative log-sigmoid of the difference. This is the core DPO loss for each sample. Using standard library functions such as logsigmoid helps maintain numerical stability:
   loss_per_sample = -torch.nn.functional.logsigmoid(diff)
   (in PyTorch) or the equivalent in your framework. Note that −log σ(z) is equivalent to softplus(−z).
5. Average the Loss: Compute the mean of loss_per_sample across the batch to get the final loss value for the training step.
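The steps above assume the per-sequence log probabilities are already available. Below is a minimal sketch, assuming a causal language model whose forward pass returns per-position token logits, of how they are commonly computed: take the log-softmax of the logits, gather the log probability of each realized token, mask out prompt and padding positions, and sum over the completion tokens. The name sequence_logprobs and the loss_mask argument are illustrative choices for this sketch, not a standard API.

import torch
import torch.nn.functional as F

def sequence_logprobs(logits: torch.Tensor,
                      labels: torch.Tensor,
                      loss_mask: torch.Tensor) -> torch.Tensor:
    """Sum of per-token log probabilities for each sequence in a batch.

    logits:    (batch, seq_len, vocab) from a causal LM forward pass on
               prompt + completion token ids
    labels:    (batch, seq_len) the same token ids (assumed to be valid
               vocabulary indices, including at padded positions)
    loss_mask: (batch, seq_len) 1.0 for completion tokens, 0.0 for prompt
               and padding tokens
    """
    # Predictions at position t correspond to the token at position t + 1
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    mask = loss_mask[:, 1:]

    # Log probability assigned to each realized token
    per_token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), dim=2,
        index=labels.unsqueeze(2)).squeeze(2)

    # Keep only completion tokens and sum to one value per sequence
    return (per_token_logps * mask).sum(dim=-1)

Applying this helper to the chosen and rejected completions under both the policy and the reference model yields the four quantities listed in step 1.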
Below is a Python function using PyTorch that demonstrates the DPO loss calculation, assuming you have already computed the necessary log probabilities.
import torch
import torch.nn.functional as F

def compute_dpo_loss(policy_chosen_logps: torch.Tensor,
                     policy_rejected_logps: torch.Tensor,
                     ref_chosen_logps: torch.Tensor,
                     ref_rejected_logps: torch.Tensor,
                     beta: float) -> torch.Tensor:
    """
    Computes the Direct Preference Optimization (DPO) loss.

    Args:
        policy_chosen_logps: Log probabilities of the chosen responses
            under the policy model. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the rejected responses
            under the policy model. Shape: (batch_size,)
        ref_chosen_logps: Log probabilities of the chosen responses
            under the reference model. Shape: (batch_size,)
        ref_rejected_logps: Log probabilities of the rejected responses
            under the reference model. Shape: (batch_size,)
        beta: Temperature parameter controlling the deviation from the
            reference model.

    Returns:
        The average DPO loss over the batch.
    """
    # Log ratios of the policy relative to the reference model (pi_ref)
    log_ratio_chosen = policy_chosen_logps - ref_chosen_logps
    log_ratio_rejected = policy_rejected_logps - ref_rejected_logps

    # Difference scaled by beta. This equals beta * (reward_chosen - reward_rejected),
    # where the reward is implicitly defined as log(pi_policy / pi_ref).
    # Equivalently: beta * ((policy_chosen_logps - policy_rejected_logps)
    #                       - (ref_chosen_logps - ref_rejected_logps))
    diff = beta * (log_ratio_chosen - log_ratio_rejected)

    # Negative log-sigmoid of the difference:
    # -log(sigmoid(diff)) == softplus(-diff); F.logsigmoid is numerically stable.
    loss = -F.logsigmoid(diff)

    # Average the loss over the batch
    average_loss = loss.mean()
    return average_loss
# --- Usage ---
# Assume these tensors come from forward passes of your models, with the
# per-token log probabilities gathered from the logits and summed over the
# response tokens for each sequence (see the helper sketched earlier).
batch_size = 8

# Example log probabilities (ensure they are properly calculated in practice)
policy_chosen_logps = torch.tensor([-10.5, -12.1, -9.8, -11.0, -13.5, -10.1, -11.8, -12.5], requires_grad=True)
policy_rejected_logps = torch.tensor([-11.2, -11.9, -10.5, -11.5, -12.8, -10.9, -12.3, -13.0], requires_grad=True)
ref_chosen_logps = torch.tensor([-10.2, -11.8, -9.5, -10.7, -13.0, -9.8, -11.5, -12.1])  # No gradients needed
ref_rejected_logps = torch.tensor([-11.0, -11.5, -10.1, -11.1, -12.2, -10.5, -11.9, -12.5])  # No gradients needed

beta_value = 0.1

# Compute the DPO loss
dpo_loss_value = compute_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                                  ref_chosen_logps, ref_rejected_logps, beta_value)
print(f"Computed DPO Loss: {dpo_loss_value.item():.4f}")

# --- In a training loop ---
# optimizer.zero_grad()
# dpo_loss_value.backward()  # Compute gradients for the policy model parameters
# optimizer.step()
A few practical notes:

- Put the reference model into evaluation mode (e.g., by calling ref_model.eval() in PyTorch) to disable dropout or other training-specific behaviors during the forward pass for log probability calculation.
- Libraries such as Hugging Face transformers often provide ways to get sequence likelihoods.
- Using F.logsigmoid prevents potential underflow/overflow issues that might arise from calculating log(1 / (1 + exp(-diff))) directly.

Consider adapting a standard language model fine-tuning script (perhaps one using Hugging Face transformers). Modify the training loop to incorporate the DPO loss calculation. You will need:

- Two models: one for the policy (policy_model) and one for the reference (ref_model). Ensure ref_model is frozen.
- Your compute_dpo_loss function (or similar) integrated into your training step (a minimal sketch of such a step follows below).

Libraries like Hugging Face's trl offer pre-built trainers such as DPOTrainer that abstract away much of this complexity, but implementing the core loss yourself provides a deeper understanding of the underlying mechanics.
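To tie the pieces together, here is a minimal sketch of what such a training step could look like. It assumes the sequence_logprobs helper sketched earlier, a trainable policy_model, a frozen ref_model, an optimizer over the policy parameters, and a batch dictionary whose keys (chosen_ids, chosen_mask, rejected_ids, rejected_mask) are illustrative names for your own data pipeline rather than a standard interface.

import torch

# Freeze the reference model once, before training starts
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad_(False)

def dpo_training_step(batch, beta=0.1):
    # Log probabilities of both completions under the trainable policy
    policy_chosen_logps = sequence_logprobs(
        policy_model(batch["chosen_ids"]).logits,
        batch["chosen_ids"], batch["chosen_mask"])
    policy_rejected_logps = sequence_logprobs(
        policy_model(batch["rejected_ids"]).logits,
        batch["rejected_ids"], batch["rejected_mask"])

    # Log probabilities under the frozen reference model; no gradients needed
    with torch.no_grad():
        ref_chosen_logps = sequence_logprobs(
            ref_model(batch["chosen_ids"]).logits,
            batch["chosen_ids"], batch["chosen_mask"])
        ref_rejected_logps = sequence_logprobs(
            ref_model(batch["rejected_ids"]).logits,
            batch["rejected_ids"], batch["rejected_mask"])

    loss = compute_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                            ref_chosen_logps, ref_rejected_logps, beta)

    optimizer.zero_grad()
    loss.backward()   # Gradients flow only into the policy model
    optimizer.step()
    return loss.item()

In a production setup you would also handle attention masks, padding, gradient accumulation, and similar details, which pre-built trainers like trl's DPOTrainer manage for you.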