Once a large language model is trained and deployed, the process doesn't stop. The information landscape constantly evolves, new data becomes available, and the distribution of data the model encounters in production might drift from its original training set. Simply freezing the model leads to staleness and degraded performance over time. Full retraining from scratch on combined old and new data is often prohibitively expensive in terms of computation and time. Continual pre-training offers a middle ground: updating the existing model with new data while attempting to retain previously learned knowledge.
The primary obstacle in continual pre-training is catastrophic forgetting. This phenomenon occurs when a model trained sequentially on different data distributions rapidly loses performance on earlier distributions as it adapts to the newer ones. The model's parameters shift significantly to minimize loss on the new data, effectively overwriting the representations important for performance on older data. Effectively managing this trade-off between learning new information (plasticity) and retaining old knowledge (stability) is the core challenge addressed by continual pre-training strategies.
The most straightforward approach is to simply continue the pre-training process using the new dataset, starting from the weights of the previously trained model. This is sometimes referred to as fine-tuning on new data.
# Example: Naive Continual Learning Setup
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

# Assume 'model' is the pre-trained LLM loaded from a checkpoint
# Assume 'new_dataset' is a PyTorch Dataset object for the new data
model_path = "path/to/your/pretrained/llm"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)  # Ensure tokenizer consistency

new_dataloader = DataLoader(new_dataset, batch_size=4, shuffle=True)

# Use a smaller learning rate than initial pre-training
optimizer = AdamW(model.parameters(), lr=1e-5)  # Example LR

num_epochs = 1  # Example: a single pass over the new data
num_training_steps = len(new_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

for epoch in range(num_epochs):
    for batch in new_dataloader:
        optimizer.zero_grad()
        # Assuming batch contains input_ids, attention_mask, labels
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

# Save the updated model
model.save_pretrained("path/to/updated/llm")
While simple, this method is highly susceptible to catastrophic forgetting, especially if the new data distribution differs significantly from the old or if the volume of new data is large. The model's parameters will drift substantially towards minimizing the loss on the new data, potentially erasing capabilities learned earlier. This approach might only be viable if the new data is very similar to the original training data or represents a minor update.
Replay strategies explicitly combat forgetting by mixing data from previous training stages with the new data during the continual learning phase. By re-exposing the model to older examples, these methods encourage it to maintain performance on the original data distribution.
The core idea is to create training batches composed of both new data samples ($D_{new}$) and samples from the old dataset ($D_{old}$).
# Example: Creating a Mixed DataLoader for Replay
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
                              WeightedRandomSampler)

# Assume 'old_dataset' and 'new_dataset' are PyTorch Dataset objects
# old_dataset might be a representative subset if the original is too large

# Option 1: Simple Concatenation
# (equal sampling probability if datasets are same size)
# combined_dataset = ConcatDataset([old_dataset, new_dataset])
# combined_dataloader = DataLoader(combined_dataset, batch_size=8, shuffle=True)

# Option 2: Weighted Sampling (control the mix ratio)
# Example: Aim for 75% new data, 25% old data per batch
new_data_weight = 0.75
old_data_weight = 0.25
weights = [new_data_weight / len(new_dataset)] * len(new_dataset) + \
          [old_data_weight / len(old_dataset)] * len(old_dataset)

combined_dataset_for_sampler = ConcatDataset([new_dataset, old_dataset])
sampler = WeightedRandomSampler(
    weights,
    num_samples=len(combined_dataset_for_sampler),
    replacement=True,
)

# Note: Do not pass shuffle=True when a sampler is provided
combined_dataloader = DataLoader(
    combined_dataset_for_sampler,
    batch_size=8,
    sampler=sampler,
)

# --- Training loop would use combined_dataloader ---
# for batch in combined_dataloader:
#     # ... rest of the training step ...
Considerations for replay methods include:
- Storage and access to old data: the full original corpus may be unavailable or too large, so a representative subset (a replay buffer) is often used instead.
- The mix ratio between new and old samples, which directly controls the trade-off between plasticity and stability and usually requires tuning.
- The added compute cost, since each update step now also processes older examples alongside the new data.
Instead of relying on explicit data replay, regularization methods modify the loss function to penalize changes to model parameters deemed important for previous tasks or data distributions.
Elastic Weight Consolidation (EWC) estimates the importance of each parameter for the old data distribution using the Fisher Information Matrix (FIM), $F$. Parameters with high diagonal values in the FIM are considered more important. During training on new data ($D_{new}$), EWC adds a quadratic penalty term to the standard loss ($L_{new}$) that discourages changes to these important parameters ($\theta_i$) relative to their values after training on the old data ($\theta^*_{old,i}$).
The EWC loss is:
$$L_{total} = L_{new}(\theta) + \frac{\lambda}{2} \sum_i F_i \,(\theta_i - \theta^*_{old,i})^2$$

Here, $\lambda$ controls the strength of the regularization. A higher $\lambda$ prioritizes retaining old knowledge.
Calculating the exact FIM is computationally expensive for LLMs. Practical implementations often use the diagonal approximation of the FIM, calculated based on gradients from the old data distribution. Even so, computing and storing these diagonal elements for billions of parameters requires careful implementation.
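The sketch below shows one way such an approximation might be implemented: estimating the diagonal FIM from a limited sample of the old data, then adding the quadratic penalty to the loss on new data. The helper names (estimate_fisher_diagonal, ewc_penalty), the old_dataloader variable, and the choice of max_batches are illustrative assumptions rather than a fixed recipe.

# Example: Diagonal Fisher estimation and EWC penalty (illustrative sketch)
import torch

def estimate_fisher_diagonal(model, old_dataloader, device, max_batches=100):
    """Approximate the diagonal FIM via squared gradients on old-distribution data."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    model.eval()
    num_batches = 0
    for i, batch in enumerate(old_dataloader):
        if i >= max_batches:  # limit cost; more batches give a better estimate
            break
        inputs = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        loss = model(**inputs).loss
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1
    return {n: f / max(num_batches, 1) for n, f in fisher.items()}

# Snapshot of parameters after training on the old data (theta*_old)
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher_diag = estimate_fisher_diagonal(model, old_dataloader, device)

def ewc_penalty(model, fisher_diag, old_params, lam=1000.0):
    """Quadratic penalty: lambda/2 * sum_i F_i (theta_i - theta*_old,i)^2."""
    penalty = 0.0
    for n, p in model.named_parameters():
        if n in fisher_diag:
            penalty = penalty + (fisher_diag[n] * (p - old_params[n]) ** 2).sum()
    return (lam / 2.0) * penalty

# Inside the training loop on new data:
# loss_total = model(**inputs).loss + ewc_penalty(model, fisher_diag, old_params)
# loss_total.backward()

In practice, the Fisher snapshot and the parameter copy double the memory needed for the frozen reference state, which is one reason the regularization strength and the number of estimation batches are usually tuned under a tight budget.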
Learning without Forgetting (LwF) uses knowledge distillation to preserve the behavior of the old model ($M_{old}$) when training on new data ($D_{new}$). The idea is to encourage the new model ($M_{new}$) to produce outputs (e.g., logits or probability distributions over the vocabulary) similar to those of the old model when processing the new data.
The total loss combines the standard cross-entropy loss on the new data's true labels ($y_{new}$) with a distillation loss ($L_{distill}$) that measures the difference between the outputs of $M_{new}$ and $M_{old}$ on $x_{new}$.

$$L_{total} = L_{CE}(M_{new}(x_{new}), y_{new}) + \lambda\, L_{distill}(M_{old}(x_{new}), M_{new}(x_{new}))$$

The distillation loss is often implemented using KL divergence between the softened probability distributions (using a temperature $T > 1$) from the two models.
# Example: LwF Loss Calculation
import torch
import torch.nn.functional as F

# Assume 'model' is the current model being trained (M_new)
# Assume 'old_model' is the frozen model from the previous stage (M_old)
# Assume 'batch' contains inputs and 'labels' for the new data
inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
labels = batch['labels'].to(device)

# Standard Cross-Entropy Loss on new data
outputs_new = model(**inputs)
logits_new = outputs_new.logits
loss_ce = F.cross_entropy(
    logits_new.view(-1, logits_new.size(-1)),
    labels.view(-1)
)

# Distillation Loss
temperature = 2.0      # Example temperature
lambda_distill = 0.5   # Example weight

with torch.no_grad():
    outputs_old = old_model(**inputs)
    logits_old = outputs_old.logits

# Softened target distribution from the frozen old model
prob_old_soft = F.softmax(logits_old / temperature, dim=-1)

# KL divergence between softened distributions:
# log-probabilities of the new model vs. probabilities of the old model
loss_distill = F.kl_div(
    F.log_softmax(logits_new / temperature, dim=-1)
        .view(-1, logits_new.size(-1)),
    prob_old_soft.view(-1, logits_old.size(-1)),
    reduction='batchmean'
) * (temperature ** 2)  # Scaling factor

# Total Loss
total_loss = (1.0 - lambda_distill) * loss_ce + \
             lambda_distill * loss_distill

# --- Backpropagation would use total_loss ---
# total_loss.backward()
# optimizer.step()
LwF avoids the need to store old data or compute parameter importance matrices, making it computationally appealing compared to replay or EWC, especially for very large models. However, its effectiveness depends on the assumption that preserving the old model's predictions on new data is a good proxy for retaining knowledge relevant to the old data distribution.
While less common for pure continual pre-training of a single monolithic LLM, architectural methods involve modifying the model's structure. Techniques like Progressive Neural Networks freeze old network parts and add new columns for new tasks. Another approach uses parameter-efficient modules like Adapters (discussed in Chapter 14): new adapters can be trained for each data increment, aiming to isolate new knowledge within these small modules while keeping the backbone frozen, as sketched below. However, ensuring the model can effectively integrate knowledge across different adapters or increments remains an open research area.
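As an illustration of the adapter-based route, the following sketch uses the Hugging Face PEFT library to freeze the backbone and train a small LoRA module on a new data increment. The target module names and hyperparameters are assumptions and depend on the specific model architecture.

# Example: Training a LoRA adapter on a new data increment (illustrative sketch)
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("path/to/your/pretrained/llm")

# LoRA configuration; target_modules must match the model's layer names
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names
    task_type="CAUSAL_LM",
)

# Wrap the backbone: only the adapter weights receive gradients
adapter_model = get_peft_model(base_model, lora_config)
adapter_model.print_trainable_parameters()

# The standard training loop over new_dataloader can be reused as-is;
# the frozen backbone limits drift, while new knowledge lands in the adapter.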
Regardless of the chosen strategy, several practical aspects are important:
- Learning rate: continual pre-training typically uses a much smaller learning rate than the original run, often with a short warmup, to limit parameter drift.
- Tokenizer and preprocessing consistency: the new data must be processed with the same tokenizer and formatting as the original training data.
- Evaluation on both distributions: held-out sets from the old and new data should be monitored throughout training so forgetting is detected early.
- Checkpointing: saving intermediate checkpoints allows rolling back if forgetting or instability appears.
Process flow for continual pre-training, highlighting the application of a strategy using new data and the essential evaluation steps on both old and new data distributions.
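A minimal sketch of that evaluation step: computing perplexity on held-out sets from both the old and new distributions after an update. The dataloader names (old_eval_dataloader, new_eval_dataloader) are assumptions.

# Example: Tracking forgetting via perplexity on old and new held-out sets
import math
import torch

def evaluate_perplexity(model, dataloader, device):
    """Perplexity from the average language-modeling loss over a held-out dataloader."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            total_loss += model(**inputs).loss.item()
            num_batches += 1
    model.train()
    return math.exp(total_loss / max(num_batches, 1))

# Assumed held-out dataloaders for the old and new distributions
ppl_old = evaluate_perplexity(model, old_eval_dataloader, device)
ppl_new = evaluate_perplexity(model, new_eval_dataloader, device)
print(f"Perplexity - old distribution: {ppl_old:.2f}, new distribution: {ppl_new:.2f}")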
Often, the most effective approaches combine elements from different strategies. For instance, using a small replay buffer alongside EWC or LwF can provide complementary benefits. The optimal strategy depends heavily on the specific constraints (compute budget, storage capacity, data characteristics) and the desired balance between retaining old knowledge and adapting to new information. Continual pre-training remains an active area of research, particularly concerning scaling these techniques effectively to foundation models with trillions of parameters.
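To make such a combination concrete, the fragment below sketches a single training step that draws mixed batches from the replay dataloader defined earlier and adds the LwF distillation term. Names such as combined_dataloader, old_model, temperature, and lambda_distill follow the earlier examples; the exact weighting is a tunable assumption.

# Example: One training step combining replay batches with an LwF distillation term
import torch
import torch.nn.functional as F

for batch in combined_dataloader:          # mixed new/old samples (replay)
    optimizer.zero_grad()
    inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
    labels = batch['labels'].to(device)

    logits_new = model(**inputs).logits
    loss_ce = F.cross_entropy(logits_new.view(-1, logits_new.size(-1)),
                              labels.view(-1))

    with torch.no_grad():
        logits_old = old_model(**inputs).logits

    # Soft-target KL term, as in the LwF example above
    loss_distill = F.kl_div(
        F.log_softmax(logits_new / temperature, dim=-1)
            .view(-1, logits_new.size(-1)),
        F.softmax(logits_old / temperature, dim=-1)
            .view(-1, logits_old.size(-1)),
        reduction='batchmean'
    ) * (temperature ** 2)

    total_loss = (1.0 - lambda_distill) * loss_ce + lambda_distill * loss_distill
    total_loss.backward()
    optimizer.step()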