As outlined in the chapter introduction, pre-trained language models, despite their impressive capabilities, often lack the specific conditioning needed to reliably follow user instructions or adhere to desired behavioral guidelines. They are trained to predict the next token in a sequence based on massive amounts of text, but this objective doesn't directly translate into helpfulness, honesty, or harmlessness as defined by human expectations. Supervised Fine-Tuning (SFT) is a technique designed to bridge this gap by explicitly teaching the model how to respond to prompts in a preferred manner.
SFT adapts a pre-trained LLM using a dataset composed of curated input prompts and their corresponding desired outputs. Think of it as providing the model with direct examples of how it should behave. Instead of learning from the implicit patterns in web-scale text, the model learns from explicit demonstrations of good responses. The process involves further training the pre-trained model on these supervised examples, typically using a standard sequence-to-sequence loss function like cross-entropy.
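For instance, a single training example might look like the following record. The field names here are illustrative only, not a fixed schema from any particular library:

# A hypothetical SFT training record; field names are illustrative.
example = {
    "prompt": "What is the capital of Malaysia?",
    "desired_response": "The capital of Malaysia is Kuala Lumpur.",
}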
At its core, SFT refines the model's parameters by minimizing the difference between the model's generated output and the target output provided in the fine-tuning dataset. The process generally follows these steps:
1. Dataset curation: collect pairs of input prompts and the desired responses to those prompts.
2. Formatting: concatenate each prompt/response pair into a single training sequence, typically with special tokens marking where the prompt ends and the response begins (a short formatting sketch follows these steps), for example:
<|prompt|> What is the capital of Malaysia? <|response|> The capital of Malaysia is Kuala Lumpur. <|endoftext|>
3. Training: continue training the pre-trained model on these sequences with a standard cross-entropy loss. Crucially, the loss is typically computed only on the tokens that make up the desired_response part of the sequence. The prompt tokens serve as context but do not contribute directly to the loss calculation or gradient updates. This targeted loss calculation is important: we want the model to learn how to generate the response given the prompt, not simply how to predict the prompt tokens themselves (which it already learned during pre-training).
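To make the formatting step concrete, here is a minimal sketch of how a prompt/response pair could be turned into a single token sequence while tracking where the prompt ends. It uses a toy whitespace tokenizer as a stand-in for a real subword tokenizer; toy_tokenize and the variable names are illustrative assumptions, not part of any specific library.

# Toy stand-in for a real tokenizer: one "token" per whitespace-separated piece.
def toy_tokenize(text):
    return text.split()

prompt = "<|prompt|> What is the capital of Malaysia? <|response|>"
response = "The capital of Malaysia is Kuala Lumpur. <|endoftext|>"

prompt_tokens = toy_tokenize(prompt)
response_tokens = toy_tokenize(response)

# The model sees one concatenated sequence; prompt_length records where the
# response begins so the loss can later be restricted to response tokens.
input_tokens = prompt_tokens + response_tokens
prompt_length = len(prompt_tokens)

print(input_tokens)
print(f"prompt_length = {prompt_length}")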
Consider the objective function. During pre-training, the model maximizes the likelihood of the entire text corpus, P(text). In SFT, the model learns a conditional probability: given a specific prompt, it maximizes the likelihood of the desired response, P(response∣prompt). This shift focuses the model on generating appropriate outputs conditioned on instructional inputs.
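Written out in standard notation (not notation defined in this section), the SFT objective for a single prompt $x$ and target response $y = (y_1, \ldots, y_T)$ is the negative log-likelihood of the response tokens, each conditioned on the prompt and the preceding response tokens:

$$
\mathcal{L}_{\text{SFT}}(\theta) = -\sum_{t=1}^{T} \log P_\theta\left(y_t \mid x,\, y_{<t}\right)
$$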
We can visualize the basic flow of information during a single SFT training step:
Figure: A simplified representation of the SFT process, showing how prompts and desired responses from the dataset are used to calculate loss and update the pre-trained LLM's weights.
To implement the targeted loss calculation in practice using frameworks like PyTorch, we typically create a loss mask. This mask ensures that only the tokens corresponding to the desired response contribute to the loss computation.
Here's a PyTorch snippet illustrating this:
import torch
import torch.nn.functional as F

# Assume:
# - logits: Model output logits [batch_size, sequence_length, vocab_size]
# - labels: Target token IDs [batch_size, sequence_length]
# - prompt_lengths: Length of the prompt part for each item in the batch [batch_size]
# - IGNORE_INDEX: A special index ignored by the loss function (e.g., -100)
IGNORE_INDEX = -100


def calculate_sft_loss(logits, labels, prompt_lengths):
    """Calculates cross-entropy loss only on response tokens."""
    batch_size, sequence_length, vocab_size = logits.shape

    # Shift logits and labels for next-token prediction:
    # logits for predicting token i are at index i-1, labels for token i are at index i.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Create a loss mask, initialized to 1 (calculate loss)
    loss_mask = torch.ones_like(shift_labels, dtype=torch.bool)

    # Set mask to 0 (ignore loss) for prompt tokens
    for i in range(batch_size):
        # Prompt length includes the initial token,
        # so mask up to prompt_length - 1 in the shifted sequence
        prompt_end_index = prompt_lengths[i] - 1
        if prompt_end_index > 0:  # Ensure there's a prompt part to mask
            loss_mask[i, :prompt_end_index] = 0

    # Apply the mask: where the mask is 0, set the label to IGNORE_INDEX
    masked_labels = shift_labels.masked_fill(~loss_mask, IGNORE_INDEX)

    # Flatten the batch and sequence dimensions for loss calculation
    shift_logits = shift_logits.view(-1, vocab_size)
    masked_labels = masked_labels.view(-1)

    # Calculate cross-entropy loss, ignoring IGNORE_INDEX positions
    loss = F.cross_entropy(shift_logits,
                           masked_labels,
                           ignore_index=IGNORE_INDEX)
    return loss


# --- Example Usage ---
# batch_size = 2
# sequence_length = 10
# vocab_size = 1000
# prompt_lengths = torch.tensor([3, 5])  # Prompt lengths for each item
# dummy_logits = torch.randn(batch_size,
#                            sequence_length,
#                            vocab_size,
#                            requires_grad=True)
# dummy_labels = torch.randint(0, vocab_size, (batch_size, sequence_length))
# sft_loss = calculate_sft_loss(dummy_logits, dummy_labels, prompt_lengths)
# print(f"Calculated SFT Loss: {sft_loss.item()}")
# sft_loss.backward()  # Compute gradients
This code snippet outlines how to mask the loss computation, ensuring only the response tokens influence the model updates during SFT.
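As a side note, the per-example Python loop used to build the mask can be replaced with a single broadcasted comparison. The sketch below is a behavior-equivalent variant under the assumption that prompt_lengths is a tensor, as in the example usage; the function name calculate_sft_loss_vectorized is introduced here for illustration.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100

def calculate_sft_loss_vectorized(logits, labels, prompt_lengths):
    """Same loss as above, but builds the mask with a broadcasted comparison."""
    batch_size, sequence_length, vocab_size = logits.shape
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # positions[j] = j, the index of each shifted label within its sequence.
    positions = torch.arange(sequence_length - 1, device=labels.device)
    # A shifted label at index j belongs to the response iff j >= prompt_length - 1.
    loss_mask = positions.unsqueeze(0) >= (prompt_lengths.to(labels.device).unsqueeze(1) - 1)

    masked_labels = shift_labels.masked_fill(~loss_mask, IGNORE_INDEX)
    loss = F.cross_entropy(shift_logits.view(-1, vocab_size),
                           masked_labels.view(-1),
                           ignore_index=IGNORE_INDEX)
    return loss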
Supervised Fine-Tuning serves several important alignment objectives: it conditions the model to follow instructions, to respond in the desired style and format, and to respect basic behavioral guidelines before any preference-based training is applied.
While SFT is effective for teaching the model what kind of response is desired based on examples, it doesn't inherently capture human preferences perfectly. It teaches the model to imitate the style and content of the provided responses. For more complex alignment goals, such as judging the relative quality between multiple plausible responses or optimizing for qualities like "helpfulness," SFT is often followed by techniques like Reinforcement Learning from Human Feedback (RLHF), which we will discuss in the next chapter. SFT provides a foundation, equipping the model with the basic ability to follow instructions before it's further refined using preference-based methods.