Once you have assembled your diverse collection of text data from sources like web crawls (e.g., Common Crawl), books, code repositories, and specialized corpora, the next challenge is deciding how much data from each source should contribute to the training process. Simply concatenating everything and shuffling uniformly might seem straightforward, but it rarely yields the best results. Different data sources possess varying levels of quality, relevance, and linguistic style. Source weighting allows you to exert fine-grained control over the composition of your training data mixture, guiding the model to prioritize learning from certain types of text over others.
The core idea is to assign a specific weight or probability to each data source. During training, when assembling a batch of data, examples are sampled from the different sources according to these predefined weights. Sources assigned higher weights will contribute more examples to the training process over time compared to sources with lower weights.
Let's assume you have $k$ distinct data sources, $S_1, S_2, \ldots, S_k$. Each source $S_i$ contains $N_i$ documents or examples. We want to define a probability $p_i$ of sampling an example from source $S_i$ at any given step. A common approach is to make this probability proportional to both the size of the source and an assigned weight $w_i$:

$$p_i \propto w_i \cdot N_i$$

Alternatively, and often more convenient to control, you can specify the desired proportion (or probability) $p_i$ for each source directly, such that $\sum_{i=1}^{k} p_i = 1$. The weights $w_i$ then simply become these target proportions $p_i$.
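As a small illustration of the proportional form, the following sketch turns assigned weights $w_i$ and source sizes $N_i$ into normalized sampling probabilities. The source names, sizes, and weights here are made up for demonstration:

    # Illustrative source sizes (N_i, in documents) and assigned weights (w_i)
    source_sizes = {"web": 1_000_000, "books": 200_000, "code": 300_000, "wiki": 100_000}
    assigned_weights = {"web": 1.0, "books": 1.5, "code": 1.0, "wiki": 2.0}

    # p_i is proportional to w_i * N_i, normalized so the probabilities sum to 1
    unnormalized = {name: assigned_weights[name] * n for name, n in source_sizes.items()}
    total = sum(unnormalized.values())
    probs = {name: value / total for name, value in unnormalized.items()}
    print(probs)
    # Approximately: {'web': 0.556, 'books': 0.167, 'code': 0.167, 'wiki': 0.111}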
For instance, you might decide on the following mixture:

    Source       Target proportion $p_i$
    Web Text     0.60
    Books        0.15
    Code         0.15
    Wikipedia    0.10

Example proportions for sampling from different data sources during LLM pre-training.

This means that, on average, 60% of the examples seen by the model in each training epoch (or over a large number of steps) will come from the Web Text dataset, 15% from Books, 15% from Code, and 10% from Wikipedia.
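It can also be useful to compare these target proportions with the "natural" proportions implied by raw source sizes, which shows how strongly each source is effectively up- or down-sampled. A minimal sketch, using illustrative sizes that match the placeholder datasets in the code further below:

    # Compare target sampling proportions with the natural proportions
    # implied by raw source sizes (sizes are illustrative).
    source_sizes = {"web": 1_000_000, "books": 200_000, "code": 300_000, "wiki": 100_000}
    target_props = {"web": 0.60, "books": 0.15, "code": 0.15, "wiki": 0.10}

    total_docs = sum(source_sizes.values())  # 1,600,000
    for name, n in source_sizes.items():
        natural = n / total_docs
        ratio = target_props[name] / natural
        print(f"{name:5s} natural={natural:.3f} target={target_props[name]:.2f} "
              f"sampling factor={ratio:.2f}x")

Under these illustrative numbers, Wikipedia would be sampled about 1.6 times more often than its raw share of documents, while the code corpus would be mildly down-sampled.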
Choosing the right weights is guided by a combination of factors, such as the quality, relevance, and volume of each source and the capabilities you want the final model to emphasize, and it typically requires careful consideration and experimentation.
Implementing source weighting typically involves modifying the data loading process. Instead of sampling uniformly from a single large dataset, the data loader needs to be aware of the different sources and their associated weights.
In PyTorch, you can achieve this by creating a custom IterableDataset or by using samplers that support weighted sampling across multiple underlying datasets. A simplified example might look like this:
import torch
import numpy as np
from torch.utils.data import IterableDataset, DataLoader

# Placeholder datasets standing in for different sources.
# In reality, these would stream actual tokenized data.
class DummyDataset(IterableDataset):
    def __init__(self, source_name, size):
        self.source_name = source_name
        self.size = size

    def __iter__(self):
        for i in range(self.size):
            # Yield dummy data: (example_data, source_identifier)
            yield torch.randn(512), self.source_name
            if i > 0 and i % 1000 == 0:  # Simulate progress through a large dataset
                print(f"Yielded {i} from {self.source_name}")

# Define sources and their desired sampling probabilities (weights)
sources = {
    "web":   {"dataset": DummyDataset("web", 1_000_000), "weight": 0.60},
    "books": {"dataset": DummyDataset("books", 200_000), "weight": 0.15},
    "code":  {"dataset": DummyDataset("code", 300_000), "weight": 0.15},
    "wiki":  {"dataset": DummyDataset("wiki", 100_000), "weight": 0.10},
}

source_names = list(sources.keys())
source_weights = np.array([sources[name]["weight"] for name in source_names])
# Normalize the weights so they sum to 1 (a no-op if they already do);
# np.random.choice requires the probabilities to sum to exactly 1.
source_weights = source_weights / source_weights.sum()

class WeightedSourceSampler(IterableDataset):
    def __init__(self, sources, source_names, source_weights):
        self.sources = sources
        self.source_names = source_names
        self.source_weights = source_weights

    def __iter__(self):
        # One iterator per source; exhausted iterators are restarted below.
        source_iters = {
            name: iter(self.sources[name]["dataset"]) for name in self.source_names
        }
        while True:
            # Choose a source according to the sampling weights
            chosen = np.random.choice(self.source_names, p=self.source_weights)
            try:
                # Get the next item from the chosen source's iterator
                yield next(source_iters[chosen])
            except StopIteration:
                # If a source is exhausted, recreate its iterator. For large-scale
                # pre-training, source iterators typically cycle indefinitely like this.
                print(f"Restarting iterator for {chosen}")
                source_iters[chosen] = iter(self.sources[chosen]["dataset"])
                try:
                    yield next(source_iters[chosen])
                except StopIteration:
                    # The source is empty even after a reset; skip this draw.
                    print(f"Warning: {chosen} is empty, skipping this sample.")
                    continue

# Create the combined dataset sampler
weighted_sampler_dataset = WeightedSourceSampler(sources, source_names, source_weights)

# Use it with a DataLoader.
# Note: num_workers > 0 requires careful handling with IterableDatasets
# (each worker gets its own copy of the dataset, so data would be duplicated
# without extra sharding logic). For simplicity, num_workers=0 is used here;
# real implementations need robust multi-worker support.
data_loader = DataLoader(weighted_sampler_dataset, batch_size=4, num_workers=0)

# Example of fetching a batch
print("Fetching a batch...")
loader_iter = iter(data_loader)
batch = next(loader_iter)
# batch[0] contains the data tensors, batch[1] contains source names
print(f"Batch data shape: {batch[0].shape}")
print(f"Batch source identifiers: {batch[1]}")

# Fetch a few more batches to see the source distribution
print("\nFetching more batches...")
for i in range(5):
    batch = next(loader_iter)
    print(f"Batch {i + 1} sources: {batch[1]}")
PyTorch implementation sketch for sampling from multiple data sources based on predefined weights.
This example demonstrates the basic principle. Production-grade data loaders for LLMs often involve more sophisticated mechanisms for efficiency, handling distributed training, and managing extremely large datasets that don't fit in memory, potentially using streaming techniques.
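A simple sanity check on the sketch above is to draw many examples and confirm that the empirical source distribution approaches the target weights. This snippet reuses data_loader, source_names, and sources from the example:

    from collections import Counter

    # Draw 500 batches of 4 examples (2,000 total) and tally their sources.
    counts = Counter()
    check_iter = iter(data_loader)
    for _ in range(500):
        _, batch_sources = next(check_iter)
        counts.update(batch_sources)

    total = sum(counts.values())
    for name in source_names:
        print(f"{name:5s} empirical={counts[name] / total:.3f} "
              f"target={sources[name]['weight']:.2f}")

With enough samples, the empirical proportions should settle close to the configured 0.60 / 0.15 / 0.15 / 0.10 mixture.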
While powerful, source weighting introduces complexity: the weights become additional hyperparameters that must be tuned through experimentation, and upweighting a small source means repeating its examples more often, which can encourage memorization.
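As a rough illustration of the repetition effect, you can estimate how many passes (effective epochs) each source receives under a hypothetical training budget. The budget and source sizes below are illustrative, not drawn from any particular training run:

    # Hypothetical total number of examples consumed during pre-training
    total_training_examples = 10_000_000
    source_sizes = {"web": 1_000_000, "books": 200_000, "code": 300_000, "wiki": 100_000}
    target_props = {"web": 0.60, "books": 0.15, "code": 0.15, "wiki": 0.10}

    for name, n in source_sizes.items():
        examples_from_source = target_props[name] * total_training_examples
        effective_epochs = examples_from_source / n
        print(f"{name:5s} ~{effective_epochs:.1f} passes over the source")

Under these numbers, the Wikipedia source would be seen about 10 times, while the web source would be seen about 6 times, so a heavily upweighted small source is repeated far more often than a large one.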
Source weighting is a fundamental technique for managing the composition of the massive datasets used in LLM pre-training. By carefully considering the quality, relevance, and volume of different data sources and assigning appropriate weights, engineers can better guide the learning process and shape the capabilities of the final model.