While randomly shuffling and sampling from a diverse data mixture is a common approach, it treats all data points as equally informative from the start. Think about how humans learn: we typically start with simpler concepts and gradually build up to more complex ones. We don't jump straight into advanced calculus before understanding basic arithmetic. Curriculum Learning (CL) applies a similar principle to training machine learning models, including LLMs. Instead of presenting data points in a purely random order drawn from the entire dataset, CL introduces a structure, often moving from "easier" examples to "harder" ones over the course of training.
The underlying idea is that starting with simpler examples can help the model establish foundational representations and avoid getting stuck in poor local minima early in training. This initial grounding might make it easier for the model to subsequently learn from more complex or noisy data. In the context of LLM pre-training, the definition of "easy" versus "hard" can take several forms.
What constitutes an "easy" or "hard" example for an LLM? This is not always straightforward, but common approaches include:
- Sequence length: shorter sequences are treated as easier than longer ones.
- Token rarity: text dominated by frequent tokens is considered easier than text full of rare words.
- Model-based difficulty: the loss or perplexity assigned by a smaller reference model serves as a proxy for how hard an example is.
- Source quality: curated, well-edited sources are treated as easier or cleaner than raw web scrapes.
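To make two of these heuristics concrete, the short sketch below scores a tiny corpus by sequence length and average token rarity using simple whitespace tokenization. It is a self-contained illustration with an arbitrary weighting, not a recipe tied to any particular tokenizer or pipeline.

import numpy as np
from collections import Counter

def difficulty_scores(docs):
    """Score documents by length and by average token rarity.

    Rarity is the negative log frequency of each token across the corpus,
    so documents full of uncommon tokens score as 'harder'.
    """
    tokenized = [doc.lower().split() for doc in docs]
    counts = Counter(tok for toks in tokenized for tok in toks)
    total = sum(counts.values())
    scores = []
    for toks in tokenized:
        length_score = len(toks)
        rarity_score = np.mean([-np.log(counts[t] / total) for t in toks])
        # Combine with arbitrary equal weighting; real systems tune this
        scores.append(0.5 * length_score + 0.5 * rarity_score)
    return scores

docs = [
    "the cat sat on the mat",
    "stochastic gradient descent minimizes the empirical risk",
]
print(difficulty_scores(docs))  # the second document scores higher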
Implementing CL requires modifying the data loading or sampling process. Instead of sampling uniformly from the dataset, the sampler needs to be aware of the training progress (e.g., current epoch or step) and select data according to the defined curriculum schedule.
A simple approach might involve bucketing the data based on a difficulty metric (like sequence length) and controlling which buckets are actively sampled from at different stages of training.
Consider a basic length-based curriculum implemented via a custom PyTorch Sampler. This example demonstrates the core logic; it is not a production-ready implementation.
import torch
from torch.utils.data import Sampler
import numpy as np
class LengthBasedCurriculumSampler(Sampler):
def __init__(self,
data_lengths,
batch_size,
start_percentile=0.1,
end_percentile=1.0,
total_steps=10000):
"""
Samples batches based on increasing sequence length percentile
over training.
Args:
data_lengths (list or np.array): List of lengths for each
data sample.
batch_size (int): The size of each batch.
start_percentile (float): Initial length percentile threshold
(0.0 to 1.0).
end_percentile (float): Final length percentile threshold
(0.0 to 1.0).
total_steps (int): Total number of training steps over which
the curriculum progresses.
"""
self.data_lengths = np.array(data_lengths)
self.indices = np.argsort(self.data_lengths) # Indices sorted by length
self.sorted_lengths = self.data_lengths[self.indices]
self.batch_size = batch_size
self.start_percentile = start_percentile
self.end_percentile = end_percentile
self.total_steps = total_steps
self.current_step = 0
self.num_samples = len(data_lengths)
# Calculate initial and final indices based on percentiles
self.start_idx = int(self.start_percentile * self.num_samples)
self.final_max_idx = int(self.end_percentile * self.num_samples)
def get_current_max_index(self):
# Linearly increase the maximum index allowed over total_steps
progress = min(1.0, self.current_step / self.total_steps)
increase = progress * (self.final_max_idx - self.start_idx)
current_max_idx = int(self.start_idx + increase)
# Ensure we always include at least the starting percentile of data
return max(self.start_idx, current_max_idx)
def __iter__(self):
current_max_idx = self.get_current_max_index()
        # Eligible indices are those up to the current maximum length threshold.
        # Copy the slice so shuffling below cannot disturb the sorted order.
        eligible_indices = self.indices[:current_max_idx].copy()
if len(eligible_indices) < self.batch_size:
# Handle cases where eligible data is too small (e.g., early steps)
# Might repeat samples or use a smaller batch
eligible_indices = np.random.choice(
eligible_indices, size=self.batch_size, replace=True
)
else:
# Shuffle the eligible indices for the current epoch/step
np.random.shuffle(eligible_indices)
# Yield batches (simplified batching logic)
num_batches = 0
for i in range(0, len(eligible_indices), self.batch_size):
batch_indices = eligible_indices[i : i + self.batch_size]
# Drop last incomplete batch for simplicity
if len(batch_indices) == self.batch_size:
yield batch_indices.tolist()
num_batches += 1
# Increment step after yielding all batches for this iteration
# In a real trainer, step update would happen per optimizer step
# This simplified version increments once per __iter__ call
# Rough step increment
self.current_step += num_batches
def __len__(self):
# Estimated number of batches per epoch/iteration
current_max_idx = self.get_current_max_index()
num_eligible = len(self.indices[:current_max_idx])
return num_eligible // self.batch_size
# --- Usage Example ---
# Assume `dataset` is your PyTorch Dataset object
# Assume `lengths` is a list containing the length of each item in `dataset`
# lengths = [len(item) for item in dataset] # Precompute lengths
#
# total_training_steps = 50000 # Example total steps
# batch_size = 32
#
# sampler = LengthBasedCurriculumSampler(
# lengths, batch_size, total_steps=total_training_steps
# )
# dataloader = torch.utils.data.DataLoader(
#     dataset, batch_sampler=sampler  # sampler yields one list of indices per batch
# )
#
# # Training loop would use this dataloader
# # for epoch in range(num_epochs):
# # for batch in dataloader:
# # # Training step...
# # # Update sampler's internal step if needed,
# # # although this example updates per iter
This sampler sorts the data by length once at initialization. On each iteration (which typically corresponds to an epoch), it determines the maximum permissible data index from the current training progress (current_step), then shuffles and yields batches containing only data points up to that length percentile. The get_current_max_index method defines the pacing of the curriculum.
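The linear ramp used by get_current_max_index is only one possible pacing function. The sketch below compares it with a square-root schedule, which expands the eligible pool faster early in training; both functions and their default fractions are illustrative choices, not prescriptions.

import numpy as np

def linear_pacing(step, total_steps, start_frac=0.1, end_frac=1.0):
    # Fraction of the length-sorted dataset eligible at `step`
    progress = min(1.0, step / total_steps)
    return start_frac + progress * (end_frac - start_frac)

def sqrt_pacing(step, total_steps, start_frac=0.1, end_frac=1.0):
    # Root pacing exposes harder data earlier than the linear schedule
    progress = min(1.0, np.sqrt(step / total_steps))
    return start_frac + progress * (end_frac - start_frac)

for step in [0, 2500, 5000, 7500, 10000]:
    print(step,
          round(linear_pacing(step, 10000), 2),
          round(sqrt_pacing(step, 10000), 2))
# At step 2500: linear allows about 0.33 of the data, sqrt about 0.55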
Potential benefits of CL include:
- More stable early training, since the model builds basic representations before confronting harder or noisier examples.
- Potentially faster convergence, as gradients from easier examples tend to be less noisy at the start.
- In some settings, better final performance or sample efficiency compared with uniform random sampling.
However, CL also presents challenges:
- Defining a reliable difficulty metric is hard, and a poor choice can bias the model.
- The pacing schedule introduces additional hyperparameters that need tuning.
- Restricting early training to "easy" data skews the distribution the model sees, and gains over well-tuned random sampling are not always observed at scale.
- The sampling logic becomes more complex, especially in distributed training where workers must agree on the curriculum state.
While explicit, complex curricula based on fine-grained difficulty metrics are not always the default for training the largest LLMs (where sophisticated data mixture weighting is often preferred for its scalability and empirical success), the core idea of curriculum learning often informs how these mixtures are designed and potentially sequenced. For instance, a multi-stage training process where the model is first trained on cleaner data before being exposed to the full, noisier dataset can be seen as a form of coarse-grained curriculum. Understanding the principles of CL provides another tool for optimizing the demanding process of LLM training.
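As a minimal sketch of such a coarse-grained, two-stage curriculum, the snippet below switches the sampling weights over data sources once a step threshold is crossed. The source names, weights, and switch point are invented for illustration.

import numpy as np

# Hypothetical data sources and per-stage sampling weights
SOURCES = ["curated_web", "books", "raw_web", "code"]
STAGE_WEIGHTS = {
    # Early training: lean on cleaner, curated sources
    "early": np.array([0.6, 0.3, 0.05, 0.05]),
    # Later training: expose the model to the full, noisier mixture
    "late": np.array([0.3, 0.2, 0.35, 0.15]),
}

def sample_source(step, switch_step=20000, rng=np.random.default_rng(0)):
    """Pick a data source for the next batch based on training progress."""
    weights = STAGE_WEIGHTS["early"] if step < switch_step else STAGE_WEIGHTS["late"]
    return rng.choice(SOURCES, p=weights)

print(sample_source(step=1000))    # most likely "curated_web"
print(sample_source(step=30000))   # noisier sources become more likely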