Supervised Fine-Tuning (SFT) adapts a pre-trained large language model to follow instructions or generate responses in a specific style by training it on a dataset of high-quality prompt-response examples. Unlike pre-training, which focuses on next-token prediction over vast, unstructured text, SFT is a more targeted form of training aimed at aligning the model's behavior with desired outcomes. The training process itself resembles standard supervised learning for sequence-to-sequence tasks but involves specific considerations regarding data formatting, loss calculation, and hyperparameter selection.
The core SFT training loop iterates through batches of prompt-response pairs, performs a forward pass, calculates a loss based on the difference between the model's predictions and the target response, and updates the model's weights via backpropagation.
Data Preparation: Each training example typically consists of a prompt (e.g., an instruction or a user query) and a desired response (e.g., the answer to the instruction or a helpful reply). These are often concatenated into a single sequence, sometimes with special tokens separating the prompt and response sections or indicating the start/end of turns in a dialogue.
# Example: formatting a single prompt-response pair for SFT.
# Assumes a causal LM tokenizer has already been loaded, e.g.:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "Instruction: Explain the process of photosynthesis."
response = "Photosynthesis is the process used by plants..."

# Add special tokens if the model/tokenizer does not already define them.
# If new tokens are added, the model's embeddings must also be resized,
# e.g. model.resize_token_embeddings(len(tokenizer)).
tokenizer.add_special_tokens({'pad_token': '[PAD]', 'eos_token': '[EOS]'})

# Simple concatenation: prompt, response, then an end-of-sequence token
input_text = prompt + " " + response + tokenizer.eos_token

# Tokenize the combined text
tokenized_input = tokenizer(
    input_text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)
inputs = tokenized_input["input_ids"]
attention_mask = tokenized_input["attention_mask"]
Forward Pass: The tokenized sequence is fed into the model to obtain logits for each token position in the sequence.
# Assuming 'model' is your pre-trained transformer model
# and 'inputs'/'attention_mask' are from the previous step
outputs = model(input_ids=inputs, attention_mask=attention_mask)
logits = outputs.logits
Loss Calculation (Masking the Prompt): This is a distinguishing feature of SFT. The goal is to teach the model to generate the response given the prompt. Therefore, the loss is typically calculated only on the response tokens. The tokens corresponding to the prompt are masked out so they don't contribute to the loss calculation or gradient updates. The standard cross-entropy loss is used on the unmasked (response) tokens.
import torch
import torch.nn.functional as F
# logits: [batch_size, sequence_length, vocab_size]
# labels: [batch_size, sequence_length] (should be inputs shifted left)
labels = inputs.clone()
# Typically, labels are shifted so model predicts next token
logits = logits[:, :-1, :] # Drop last logit
labels = labels[:, 1:] # Drop first token (e.g., BOS)
# Determine prompt length for each item in batch
# This requires knowing where the prompt ends and response begins
# For simplicity, assume prompt_length is known for each example
# prompt_lengths: [batch_size] tensor with length of prompt tokens for each example
# Build a mask using -100 as the "ignore" sentinel (the same convention
# PyTorch's CrossEntropyLoss uses for ignore_index)
loss_mask = torch.ones_like(labels, dtype=torch.long)
for i in range(labels.shape[0]):
    # Mask out prompt tokens (and padding tokens if tokenizer.pad_token_id exists)
    prompt_end_index = prompt_lengths[i] - 1  # Adjust based on how length is defined
    loss_mask[i, :prompt_end_index] = -100
    if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
        loss_mask[i][labels[i] == tokenizer.pad_token_id] = -100  # Mask padding
# Flatten logits and labels for CrossEntropyLoss, applying the mask
# Only compute loss on tokens where loss_mask is not -100
active_loss = loss_mask.view(-1) != -100
active_logits = logits.view(-1, logits.size(-1))[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = F.cross_entropy(active_logits, active_labels)
In practice, libraries like Hugging Face's transformers.Trainer or TRL's trl.SFTTrainer handle this masking logic internally if the data is formatted correctly (e.g., using specific dataset formats or providing a data collator).
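As a rough illustration of the library route, the sketch below uses TRL's DataCollatorForCompletionOnlyLM, which masks every token before a response template string so that loss is computed only on response tokens. The exact keyword arguments vary across TRL versions, and the model, tokenizer, and train_dataset (with a formatted "text" column) are assumed to already exist.
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# Marker assumed to separate prompt from response in each formatted example
response_template = "### Response:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Sketch only: `model` and `train_dataset` are assumed to be defined;
# argument names may differ between TRL versions.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    data_collator=collator,
)
trainer.train()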
Backward Pass and Optimization: Standard backpropagation computes gradients based on the calculated loss. An optimizer, typically AdamW, updates the model weights.
# Using standard PyTorch optimization steps
# (assumes an optimizer, e.g. torch.optim.AdamW over model.parameters(),
# and an optional learning rate scheduler have already been created)
optimizer.zero_grad()
loss.backward()
# Optional: clip the L2 norm of the gradients to stabilize training
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
if scheduler is not None:
    scheduler.step()
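Putting these steps together, a bare-bones SFT epoch loop might look like the following sketch. It assumes a PyTorch DataLoader named train_loader whose batches already contain input_ids, attention_mask, and labels with prompt and padding positions set to -100, plus a device; these names are illustrative, not part of any fixed API.
import torch

# Minimal SFT training loop sketch. Assumes `model`, `train_loader`, and
# `device` are already set up, and that each batch carries labels with
# prompt/padding positions set to -100.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

model.train()
for epoch in range(3):
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)  # -100 on prompt/padding tokens

        # Hugging Face causal LM models shift labels internally and ignore
        # positions set to -100 when computing the cross-entropy loss.
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()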
Choosing appropriate hyperparameters is significant for successful SFT. Since SFT adapts an already powerful pre-trained model, the settings differ from those used during pre-training.
Learning Rate: SFT typically uses a much smaller learning rate than pre-training. Values in the range of 1×10⁻⁵ to 5×10⁻⁵ are common. The smaller rate prevents catastrophic forgetting of the knowledge acquired during pre-training while allowing adaptation to the new instruction-following objective. A learning rate scheduler, like cosine decay with a short warmup phase (e.g., 0-10% of total steps), is often beneficial.
Figure: A typical learning rate schedule for SFT, with a short warmup followed by cosine decay.
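The transformers library ships helpers for this kind of schedule. The sketch below uses get_cosine_schedule_with_warmup and reuses the optimizer from the loop above; the step counts are illustrative assumptions rather than recommended values.
from transformers import get_cosine_schedule_with_warmup

# Placeholder step counts; in practice, derive them from the dataset size,
# batch size, and number of epochs.
num_training_steps = 1_000
num_warmup_steps = int(0.05 * num_training_steps)  # e.g., 5% of steps as warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
# scheduler.step() is then called after each optimizer.step(), as shown earlier.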
Batch Size: The batch size is often constrained by GPU memory. Larger models require more memory, limiting the number of sequences per batch. Gradient accumulation is commonly used to achieve a larger effective batch size without increasing memory usage per GPU. Typical effective batch sizes might range from 64 to 1024, depending on the model size and available hardware.
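A sketch of gradient accumulation, reusing the hypothetical train_loader and device from above: each micro-batch loss is scaled down and gradients accumulate over several micro-batches before a single optimizer step.
accumulation_steps = 8  # effective batch size = per-GPU batch size * 8

optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    outputs = model(input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                    labels=batch["labels"].to(device))
    # Scale the loss so gradients average over the accumulated micro-batches
    (outputs.loss / accumulation_steps).backward()

    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()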
Number of Epochs: SFT usually requires only a few epochs (often 1-3, sometimes up to 5). Training for too long can lead to overfitting on the specific SFT dataset, potentially reducing the model's ability to generalize to unseen instructions or degrading its broad knowledge. Monitoring performance on a validation set is important.
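A simple way to watch for overfitting is to compute the loss on a held-out validation split after each epoch. The sketch below assumes a hypothetical val_loader with the same batch format as train_loader.
@torch.no_grad()
def evaluate(model, val_loader, device):
    # Average validation loss over the held-out split
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in val_loader:
        outputs = model(input_ids=batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        labels=batch["labels"].to(device))
        total_loss += outputs.loss.item()
        num_batches += 1
    model.train()
    return total_loss / max(num_batches, 1)

# Stop training or roll back to an earlier checkpoint if this value rises
# for consecutive epochs.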
Optimizer: AdamW remains the standard choice, similar to pre-training. The weight decay parameter might be kept the same or slightly adjusted (e.g., 0.01 to 0.1). The beta parameters (β1, β2) are often kept at their defaults (e.g., 0.9 and 0.999).
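These settings map directly onto PyTorch's AdamW constructor; the values below are illustrative, not prescriptive.
# Illustrative AdamW configuration for SFT
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)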
Sequence Length: The maximum sequence length should accommodate the combined length of typical prompts and responses in the SFT dataset. It might differ from the sequence length used during pre-training. Packing multiple short examples into one sequence or using dynamic padding can improve efficiency.
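For example, dynamic padding pads each batch only to its longest sequence rather than to a fixed max_length, which avoids wasted computation on padding tokens. A sketch, with example texts that are purely illustrative:
batch_texts = [
    "Instruction: Explain photosynthesis. Photosynthesis is the process...",
    "Instruction: Define entropy. Entropy is a measure of disorder...",
]

# padding="longest" pads only to the longest example in this batch,
# while truncation still caps very long sequences at max_length.
batch = tokenizer(
    batch_texts,
    padding="longest",
    truncation=True,
    max_length=1024,
    return_tensors="pt",
)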
Gradient Clipping: Applying gradient clipping (e.g., clipping the L2 norm of gradients to 1.0) helps stabilize training, although instabilities are generally less frequent in SFT compared to large-scale pre-training.
The SFT process fine-tunes the model's capabilities, steering its pre-trained knowledge towards specific interaction patterns and task execution formats defined by the instruction dataset. Careful management of the training loop and hyperparameters ensures this adaptation happens effectively without compromising the model's underlying strengths.