Integrating new data sources into an existing large language model's training pipeline is essential for keeping the model up-to-date, expanding its knowledge domains, and potentially correcting biases or knowledge gaps discovered post-deployment. However, this process is not without risks. Simply adding new data can introduce noise or harmful content, or cause the model to forget previously learned information (catastrophic forgetting). A systematic and cautious approach is required to integrate new data safely and effectively.
Before incorporating any new dataset, it must undergo rigorous vetting, similar to the processes described in Chapter 7 for initial data preprocessing. The stakes can be even higher during continuous training, as regressions in model quality can impact live applications.
Once a new data source has been vetted, several strategies can be employed to integrate it into the continuous training process.
Simple Mixing and Resampling: The most straightforward approach is to add the cleaned new data to the existing training pool and continue training, resampling from the combined dataset. This requires careful consideration of the mixing ratio or source weights (Chapter 9). Giving too much weight to new data can accelerate forgetting, while too little weight might make learning from the new data inefficient. The optimal ratio often depends on the size and relevance of the new data relative to the existing corpus.
Data Rehearsal (Replay): To explicitly combat catastrophic forgetting, a common technique is rehearsal or replay. Instead of training solely on new data or a simple mix, each training batch is constructed using a combination of new data and a sampled subset of the old data. This forces the model to revisit previous knowledge while learning new information. The proportion of old versus new data in each batch becomes a critical hyperparameter. Sampling from the old data can be uniform or based on more sophisticated strategies, though uniform sampling is often effective and simpler to implement at scale. A sketch of a rehearsal sampler appears after the mixing example below.
Curriculum Learning: Introduce the new data gradually. This might involve starting with a low sampling weight for the new source and progressively increasing it over time (an annealing schedule, see Chapter 9). Alternatively, if the new data represents a distinctly different domain, one might structure the curriculum to first reinforce general knowledge with old data before heavily focusing on the new domain.
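To make the annealing idea concrete, the sketch below ramps the new source's target share of each batch linearly from 5% to 30% over a fixed number of steps and recomputes per-sample weights for a weighted sampler accordingly. The function names, ramp endpoints, and step counts are illustrative assumptions, not recommended values.

import torch

def new_source_fraction(step, ramp_steps=10_000,
                        start_frac=0.05, end_frac=0.30):
    # Linearly anneal the new source's target share of each batch.
    # ramp_steps, start_frac, and end_frac are illustrative, not prescriptive.
    progress = min(step / ramp_steps, 1.0)
    return start_frac + progress * (end_frac - start_frac)

def build_sample_weights(old_size, new_size, new_frac):
    # Per-sample weights so a weighted sampler draws roughly `new_frac`
    # of its indices from the new source and the rest from the old.
    old_weight = (1.0 - new_frac) / old_size
    new_weight = new_frac / new_size
    return torch.cat([
        torch.full((old_size,), old_weight),
        torch.full((new_size,), new_weight),
    ])

# Recompute the weights periodically as training progresses.
for step in (0, 5_000, 10_000):
    frac = new_source_fraction(step)
    weights = build_sample_weights(old_size=1_000_000, new_size=200_000,
                                   new_frac=frac)
    print(f"step {step}: new-source share of each batch ~ {frac:.2f}")

In practice, the sampler would be rebuilt with the updated weight tensor when the schedule advances, typically between effective epochs, as in the mixing example later in this section.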
A simplified flow showing vetted old and new data sources being combined by a sampler before feeding into the continuous training process. Rehearsal involves sampling from the existing corpus.
Careful monitoring is essential when training with new data sources. Track validation loss or perplexity on held-out sets drawn from both the original corpus and the new source, along with any downstream evaluation suites, so that regressions or signs of forgetting surface quickly.
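As one minimal sketch of such monitoring, the function below computes perplexity over a held-out loader. It assumes a Hugging Face-style causal language model that returns an object with a loss attribute, and the loader names old_heldout_loader and new_heldout_loader are hypothetical.

import math
import torch

@torch.no_grad()
def eval_perplexity(model, data_loader, device="cpu"):
    # Average cross-entropy over a held-out loader, reported as perplexity.
    # Assumes the model returns an object with a `.loss` attribute when
    # called with input_ids and labels (as Hugging Face causal LMs do).
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(input_ids=input_ids, labels=labels)
        total_loss += outputs.loss.item()
        num_batches += 1
    return math.exp(total_loss / max(num_batches, 1))

# At each checkpoint, evaluate both domains (loader names are hypothetical):
# ppl_old = eval_perplexity(model, old_heldout_loader, device)
# ppl_new = eval_perplexity(model, new_heldout_loader, device)
# A rising ppl_old while ppl_new falls is a classic signature of forgetting.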
Managing multiple data sources often involves creating custom Datasets or Samplers in PyTorch. Here's an example using torch.utils.data.ConcatDataset and WeightedRandomSampler for simple mixing with source weighting.
import torch
from torch.utils.data import (Dataset, ConcatDataset, DataLoader,
                              WeightedRandomSampler)

# Assume OldDataset and NewDataset are PyTorch Dataset instances
# loaded with pre-processed, tokenized data paths or objects.
# old_data_paths = [...]  # Paths to files/shards for old data
# new_data_paths = [...]  # Paths to files/shards for new data

# class YourCustomDataset(Dataset):
#     def __init__(self, data_paths):
#         self.data_paths = data_paths
#         # Initialize logic to load/access data items
#         # self.index = self._build_index()  # Example: map indices to file/offset
#
#     def __len__(self):
#         # Return total number of samples
#         # return len(self.index)
#         pass  # Placeholder
#
#     def __getitem__(self, idx):
#         # Load and return tokenized sample corresponding to idx
#         # sample = self._load_sample(self.index[idx])
#         # return sample
#         pass  # Placeholder


# Replace with actual dataset implementations
class PlaceholderDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Simulate loading a tokenized sample
        return {
            "input_ids": torch.randint(0, 50000, (1024,)),
            "labels": torch.randint(0, 50000, (1024,))
        }


old_dataset = PlaceholderDataset(num_samples=1_000_000)
new_dataset = PlaceholderDataset(num_samples=200_000)

# Combine datasets
combined_dataset = ConcatDataset([old_dataset, new_dataset])

# Define sampling weights - e.g., sample new data more frequently
# than its size ratio.
# Aim for new data to be ~30% of each batch, despite being <20% of total size.
old_data_weight = 0.7 / len(old_dataset)
new_data_weight = 0.3 / len(new_dataset)

sample_weights = torch.cat([
    torch.full((len(old_dataset),), old_data_weight),
    torch.full((len(new_dataset),), new_data_weight)
])

# Use WeightedRandomSampler.
# 'replacement=True' is typical for large datasets to avoid iterating
# through all samples once per epoch.
# 'num_samples' defines the effective size of an epoch, i.e. how many
# samples constitute an "epoch" for scheduling purposes.
effective_epoch_size = 500_000
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=effective_epoch_size,
    replacement=True
)

# Create DataLoader.
# Adjust batch_size, num_workers, etc. based on hardware and
# distributed setup.
batch_size = 8
data_loader = DataLoader(
    combined_dataset,
    sampler=sampler,
    batch_size=batch_size,
    num_workers=4
)

# Training loop using the data_loader
# for batch in data_loader:
#     # Perform forward pass, backward pass, optimizer step
#     # input_ids = batch['input_ids'].to(device)
#     # labels = batch['labels'].to(device)
#     # ... model training step ...
#     pass  # Placeholder

print(f"Combined dataset size: {len(combined_dataset)}")
print(f"Sampler using {len(sample_weights)} weights, sampling "
      f"{effective_epoch_size} indices per epoch.")

# The first few weights reflect the lower probability for old data samples
print(f"Sample weights (first 5): {sample_weights[:5]}")
# Weights towards the end reflect the higher probability for new data samples
print(f"Sample weights (last 5): {sample_weights[-5:]}")
Example PyTorch setup combining two datasets and using WeightedRandomSampler to control the mixing ratio during training. This approach implements simple mixing with source weighting. Implementing rehearsal requires a more complex sampler or dataset logic to explicitly fetch samples from both old and new sources within each batch.
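One minimal sketch of such logic, reusing old_dataset, new_dataset, combined_dataset, and effective_epoch_size from the example above, is a custom batch sampler that draws a fixed quota of indices from each source for every batch. The class name RehearsalBatchSampler and the 6-old/2-new split are illustrative choices, not values prescribed in this chapter.

import torch
from torch.utils.data import DataLoader

class RehearsalBatchSampler:
    # Yields index batches for ConcatDataset([old_dataset, new_dataset]):
    # old samples occupy indices [0, old_size), new samples occupy
    # [old_size, old_size + new_size).
    def __init__(self, old_size, new_size, batch_size,
                 new_per_batch, num_batches):
        self.old_size = old_size
        self.new_size = new_size
        self.old_per_batch = batch_size - new_per_batch
        self.new_per_batch = new_per_batch
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            # Uniformly replay old data and mix in a fixed share of new data.
            old_idx = torch.randint(0, self.old_size, (self.old_per_batch,))
            new_idx = torch.randint(0, self.new_size,
                                    (self.new_per_batch,)) + self.old_size
            batch = torch.cat([old_idx, new_idx])
            yield batch[torch.randperm(len(batch))].tolist()

    def __len__(self):
        return self.num_batches

# Example: 8-sample batches, each containing 2 new samples (a 75/25 mix).
rehearsal_loader = DataLoader(
    combined_dataset,
    batch_sampler=RehearsalBatchSampler(
        old_size=len(old_dataset),
        new_size=len(new_dataset),
        batch_size=8,
        new_per_batch=2,
        num_batches=effective_epoch_size // 8,
    ),
    num_workers=4,
)

Because each batch is constructed explicitly, the old/new proportion is exact per batch rather than holding only in expectation, which makes the rehearsal ratio easier to reason about and tune.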
If monitoring reveals problems like significant performance regressions or training instability after introducing new data, be prepared to respond: reduce the new source's sampling weight, pause its integration while the data is re-vetted, or roll the model back to the last healthy checkpoint before resuming.
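As a rough illustration of the adjust-or-roll-back response, the snippet below reloads an earlier checkpoint and rebuilds the sampler from the mixing example with a reduced new-source share. The checkpoint path and the 10% share are placeholders, and the model/optimizer restore lines are left commented because they depend on your training setup.

import torch
from torch.utils.data import WeightedRandomSampler

# Restore the last checkpoint that passed evaluation (path is a placeholder).
# checkpoint = torch.load("checkpoints/last_healthy.pt", map_location="cpu")
# model.load_state_dict(checkpoint["model"])
# optimizer.load_state_dict(checkpoint["optimizer"])

# Rebuild the sampler with a reduced new-source share (e.g., 30% down to 10%),
# reusing old_dataset, new_dataset, and effective_epoch_size from above.
reduced_new_frac = 0.10
sample_weights = torch.cat([
    torch.full((len(old_dataset),),
               (1.0 - reduced_new_frac) / len(old_dataset)),
    torch.full((len(new_dataset),),
               reduced_new_frac / len(new_dataset)),
])
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=effective_epoch_size,
    replacement=True,
)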
Incorporating new data sources is a powerful tool for model evolution, but it demands a careful, methodical approach. Rigorous vetting, strategic integration, continuous monitoring, and the readiness to adjust or roll back are all necessary components for safely updating large language models in dynamic environments.