Once your massive text dataset is cleaned, processed, and stored, likely across a distributed file system or cloud storage (as discussed in "Distributed File Systems (HDFS, S3)"), the challenge shifts to efficiently feeding this data into your distributed training setup. Loading multi-terabyte datasets entirely into memory is impossible. Even reading them fully from disk before training starts is often impractical. This is where streaming data loaders become essential.
Instead of loading the entire dataset upfront, as map-style datasets typically do, streaming data loaders read data dynamically, sample by sample or chunk by chunk, directly from storage during the training process. This approach keeps memory usage low and allows training to begin almost immediately, but it introduces new complexities, particularly around I/O performance, data shuffling, and coordination in a distributed environment.
Modern GPUs can process data incredibly quickly. If your data loader cannot supply training batches faster than the GPU can consume them, the expensive accelerators will sit idle, wasting compute resources and extending training time significantly. This is the I/O bottleneck problem.
A streaming data loader must be designed to:

- read and decode samples in parallel across multiple worker processes,
- prefetch batches ahead of the training loop so the GPU never waits on I/O,
- keep per-worker memory bounded regardless of dataset size, and
- shuffle data reasonably well even though the full dataset is never in memory at once.
PyTorch's torch.utils.data.DataLoader provides built-in support for parallel workers (num_workers) and prefetching (prefetch_factor), which are indispensable for streaming. However, the core logic of how data is fetched and processed sample by sample resides within an IterableDataset.

PyTorch distinguishes between map-style datasets (implementing __getitem__ and __len__) and iterable-style datasets (implementing __iter__). For streaming large datasets, IterableDataset is the natural choice.

An IterableDataset's __iter__ method is responsible for yielding one processed sample at a time. When used with a DataLoader and multiple workers (num_workers > 0), PyTorch handles distributing the workload: each worker gets its own iterator instance, typically configured to process a distinct subset of the data shards.
Here's a skeleton of an IterableDataset designed for streaming:
import os
import fnmatch
import torch
from torch.utils.data import IterableDataset, DataLoader


class StreamingTextDataset(IterableDataset):
    def __init__(self, data_dir, shard_pattern="shard_*.txt",
                 shuffle_shards=True, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.shard_files = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if fnmatch.fnmatch(f, shard_pattern)
        )
        self.shuffle_shards = shuffle_shards
        # Fixed base seed so every worker computes the same shard permutation.
        # Fold in the epoch number (e.g., via a set_epoch method) if you want
        # a different order each epoch.
        self.seed = seed

    def _iter_shard(self, shard_file):
        # In practice, use efficient reading and parsing
        # (e.g., for Arrow/Parquet).
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        print(f"Worker {worker_id}: Processing {shard_file}")
        try:
            with open(shard_file, "r", encoding="utf-8") as f:
                for line in f:
                    # Simulate processing: yield a processed sample
                    # (e.g., tokenized IDs).
                    # Replace with actual tokenization/processing.
                    processed_sample = line.strip()
                    if processed_sample:
                        yield processed_sample
        except Exception as e:
            print(f"Error processing shard {shard_file}: {e}")

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        current_shards = self.shard_files
        if self.shuffle_shards:
            # All workers must shuffle with the same seed so they agree on
            # the shard order; otherwise the round-robin split below would
            # process some shards twice and skip others.
            g = torch.Generator()
            g.manual_seed(self.seed)
            shard_indices = torch.randperm(
                len(current_shards), generator=g
            ).tolist()
            current_shards = [self.shard_files[i] for i in shard_indices]

        if worker_info is None:
            # Single-process loading
            worker_id = 0
            num_workers = 1
            shards_for_worker = current_shards
        else:
            # Multi-process loading
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            # Assign shards to workers (simple round-robin distribution)
            shards_for_worker = [
                s for i, s in enumerate(current_shards)
                if i % num_workers == worker_id
            ]

        print(
            f"Worker {worker_id}/{num_workers}: "
            f"Assigned {len(shards_for_worker)} shards."
        )
        for shard_file in shards_for_worker:
            yield from self._iter_shard(shard_file)


# Usage example
# Assume 'my_large_dataset_shards/' contains files like shard_001.txt,
# shard_002.txt, ...
# dataset = StreamingTextDataset(data_dir='my_large_dataset_shards/')
# dataloader = DataLoader(
#     dataset,
#     batch_size=32,
#     num_workers=4,
#     prefetch_factor=2,
# )
#
# for epoch in range(3):
#     print(f"\n--- Epoch {epoch} ---")
#     for batch in dataloader:
#         # model training step with batch
#         # print(f"Received batch of size: {len(batch)}")
#         # batch would be a list of strings here
#         pass  # Replace with training logic
In this example:

- __init__ identifies the data shards (e.g., individual files).
- __iter__ is called once per worker. It determines which shards this specific worker should process based on worker_info and optionally shuffles the order of shards.
- _iter_shard handles reading and yielding samples from a single shard.

True random shuffling requires loading the entire dataset, which is impossible. Streaming loaders therefore use approximate shuffling techniques, most commonly shuffling the order of shards combined with a shuffle buffer.
The shuffle buffer reads data sequentially but yields samples randomly from a fixed-size memory buffer, providing better local shuffling than just shuffling shards.
The size of the shuffle buffer is a trade-off: larger buffers provide better randomness but consume more memory per worker.
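To make the mechanics concrete, here is a minimal, framework-agnostic sketch of a shuffle buffer; the function name and parameters are illustrative and not part of any library:

import random

def shuffle_buffer(stream, buffer_size, seed=None):
    """Yield samples from `stream` in approximately random order: fill a
    fixed-size buffer, then repeatedly emit a random slot and refill it
    with the next incoming sample."""
    rng = random.Random(seed)
    buffer = []
    for sample in stream:
        if len(buffer) < buffer_size:
            buffer.append(sample)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = sample
    # Drain whatever remains once the stream is exhausted.
    rng.shuffle(buffer)
    yield from buffer

In the StreamingTextDataset above, __iter__ could wrap its per-worker sample stream with such a buffer, for example yield from shuffle_buffer(samples, buffer_size=10_000), combining shard-level and sample-level shuffling.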
While you can build streaming loaders using a raw PyTorch IterableDataset, several libraries offer pre-built, highly optimized solutions specifically for large-scale deep learning:

- webdataset: Excellent for data stored as sequences of files, especially .tar archives. It provides flexible pipelines for decoding, augmenting, and shuffling data streams. It's widely used for multimodal datasets but works well for text too.
- StreamingDataset (from MosaicML's streaming library): Designed for cloud-native training. It stores data in its own optimized format (MDS) on object storage (like S3). It handles efficient sharding, shuffling (using shard shuffling and shuffle buffer concepts), and seamless resumption across workers and nodes automatically. It requires converting your dataset to MDS format first.
- Hugging Face datasets: The popular library now includes streaming capabilities (load_dataset(..., streaming=True)). It allows iterating over large datasets directly from the Hugging Face Hub or local storage without downloading everything first. It integrates well with their tokenizers and models (a short usage sketch follows below).

Using these libraries can significantly simplify development and often provides better performance due to specialized optimizations for I/O, caching, and shuffling.
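For illustration, here is roughly what streaming with the datasets library looks like; the corpus name below is just a placeholder for any large dataset on the Hub:

from datasets import load_dataset

# Stream examples directly from the Hugging Face Hub; nothing is fully
# downloaded up front. "allenai/c4" / "en" is a placeholder corpus.
stream = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Approximate shuffling: shard shuffling plus a fixed-size shuffle buffer.
stream = stream.shuffle(seed=42, buffer_size=10_000)

for i, example in enumerate(stream):
    print(example["text"][:80])
    if i >= 2:
        break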
Long LLM training jobs inevitably face interruptions (hardware failures, preemption). A streaming data loader must support resumption, meaning it can restart exactly where it left off in the data stream.
This requires checkpointing the state of the data loader alongside the model weights and optimizer state (see Chapter 19: "Checkpointing and Fault Tolerance"). The state typically includes:

- the current epoch and the shuffle seed or RNG state,
- the shard order and the index of the shard each worker is currently reading,
- the position (sample count or byte offset) within that shard, and
- the contents or fill level of any shuffle buffer.
Libraries like StreamingDataset manage this state automatically. If building a custom loader, you need to explicitly collect this state from all workers during checkpointing and restore it when resuming.
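One common approach for a custom loader is to give the dataset a state_dict-style interface that the training loop saves and restores together with the rest of the checkpoint. The sketch below is hypothetical, extends the StreamingTextDataset from earlier, and omits the skipping logic that __iter__ would need on resume:

class ResumableStreamingTextDataset(StreamingTextDataset):
    """Hypothetical extension that tracks enough state to resume mid-epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._progress = {
            "epoch": 0,             # current epoch number
            "shard_index": 0,       # index into this worker's shard list
            "samples_in_shard": 0,  # samples already yielded from that shard
        }

    def state_dict(self):
        # Saved alongside model and optimizer state at checkpoint time;
        # with multiple workers this must be gathered from every worker.
        return dict(self._progress)

    def load_state_dict(self, state):
        # Restored before iteration resumes; __iter__ would then skip the
        # already-consumed shards and samples recorded here.
        self._progress.update(state)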
Several settings have a large effect on streaming throughput:

- num_workers: Set this to utilize multiple CPU cores for data loading. The optimal value depends on the CPU, storage speed, and data processing complexity. Start with the number of physical CPU cores available per GPU and experiment.
- prefetch_factor: Controls how many batches are preloaded per worker. A value of 2 is often a good starting point (a fuller configuration sketch follows this list).
- Storage access pattern: fetching many small files individually is far slower than streaming a smaller number of large, sequentially read shards; formats such as .tar archives or MDS and libraries like StreamingDataset are optimized for the latter.
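Putting these settings together, a tuned DataLoader for the StreamingTextDataset above might look like the following sketch; the specific values are starting points to profile on your own hardware, not recommendations:

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,                  # a StreamingTextDataset instance
    batch_size=32,
    num_workers=8,            # roughly the physical CPU cores available per GPU
    prefetch_factor=2,        # batches preloaded per worker
    persistent_workers=True,  # keep workers alive across epochs
    pin_memory=True,          # faster host-to-GPU copies for tensor batches
)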
Effectively managing and streaming data is a critical infrastructure component for training large language models. By choosing appropriate storage formats, leveraging distributed file systems, and implementing efficient, resumable streaming data loaders, you can ensure your GPUs are consistently fed with data, enabling successful training at scale.