When datasets grow too large to fit comfortably in system memory (RAM), standard data loading approaches using PyTorch's map-style Dataset can become bottlenecks or outright impossible. Map-style datasets, which implement __getitem__ and __len__, often assume the ability to randomly access any item by index and may load metadata for the entire dataset upfront. This section details strategies designed for handling these massive datasets, focusing on streaming data efficiently using IterableDataset.
Consider a terabyte-scale image dataset stored across thousands of files. A standard Dataset might try to build a list of all file paths and corresponding labels in its __init__ method. Even if the images themselves aren't loaded, this metadata alone could exceed available RAM. Furthermore, the random access requirement of __getitem__ might be inefficient if data needs to be read sequentially from large compressed files or database queries. Shuffling a large map-style dataset also typically involves creating a shuffled list of indices covering the entire dataset size N, which again requires significant memory for large N.
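To make this concrete, here is a minimal sketch of the map-style pattern described above. The manifest file name and its path,label layout are hypothetical; the point is that both the metadata list built in __init__ and the shuffled index list grow linearly with the number of samples.

import random
from torch.utils.data import Dataset

class ManifestImageDataset(Dataset):
    def __init__(self, manifest_path):
        super().__init__()
        # Loads one (path, label) entry per sample into RAM up front.
        # With billions of samples, this list alone can exhaust memory.
        self.samples = []
        with open(manifest_path, 'r') as f:
            for line in f:
                path, label = line.strip().split(',')
                self.samples.append((path, int(label)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # Actual image loading/decoding would happen here; omitted in this sketch.
        return path, label

# Shuffling a map-style dataset means permuting all N indices in memory:
# indices = list(range(len(dataset)))
# random.shuffle(indices)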
PyTorch provides an alternative: torch.utils.data.IterableDataset. Instead of defining __getitem__ and __len__, you implement the __iter__ method. This method should return an iterator that yields one sample at a time. This approach is fundamentally different; it treats the dataset as a stream of data rather than an indexable collection.
IterableDataset is particularly well-suited for scenarios where:

- the dataset, or even just its index or metadata, is too large to hold in memory,
- data is naturally read as a sequential stream (large compressed files, database cursors, network sources), making random access expensive, or
- the total number of samples is not known in advance.
Here’s a conceptual implementation reading samples line-by-line from a large file:
import torch
from torch.utils.data import IterableDataset, DataLoader

class LargeTextFileDataset(IterableDataset):
    def __init__(self, file_path, tokenizer):
        super().__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer

    def __iter__(self):
        # The iterator is created here for each epoch/worker
        file_iterator = open(self.file_path, 'r')
        # map lazily applies the processing function to each line from the iterator
        return map(self.tokenizer, file_iterator)

# Usage:
# tokenizer = lambda line: torch.tensor([int(x) for x in line.strip().split(',')])
# dataset = LargeTextFileDataset('very_large_data.csv', tokenizer)
# loader = DataLoader(dataset, batch_size=32)
#
# for batch in loader:
#     # Process batch
#     pass
In this example, open(self.file_path, 'r') returns an iterator over the lines of the file. The map function then lazily applies the tokenizer to each line as it's requested by the DataLoader. No attempt is made to load the entire file into memory.
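Equivalently, __iter__ can be written as a generator. The following is a minimal sketch of the same dataset under that style, assuming the same tokenizer callable; the with block also closes the file once the stream is exhausted.

from torch.utils.data import IterableDataset

class LargeTextFileDataset(IterableDataset):
    def __init__(self, file_path, tokenizer):
        super().__init__()
        self.file_path = file_path
        self.tokenizer = tokenizer

    def __iter__(self):
        # The with block guarantees the file handle is released
        # once the DataLoader finishes consuming the stream.
        with open(self.file_path, 'r') as f:
            for line in f:
                yield self.tokenizer(line)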
When using DataLoader with num_workers > 0, each worker process gets a copy of the IterableDataset instance. A crucial aspect is ensuring that each worker processes a distinct portion of the data stream to avoid duplication. If not handled correctly, every worker might start reading the same large file from the beginning, leading to redundant work and incorrect effective batch composition.
The standard way to address this is within the __iter__ method, using torch.utils.data.get_worker_info():
import torch
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class ShardedLargeFileDataset(IterableDataset):
    def __init__(self, file_path, processor_fn):
        super().__init__()
        self.file_path = file_path
        self.processor_fn = processor_fn
        # Determine file size or number of lines/records if needed for sharding
        # self.num_records = self._get_num_records(file_path)  # Example helper

    def _get_records_iterator(self):
        # Replace this with logic to iterate over your specific data records/files
        with open(self.file_path, 'r') as f:
            for line in f:
                yield line  # Yield raw records

    def __iter__(self):
        worker_info = get_worker_info()
        record_iterator = self._get_records_iterator()

        if worker_info is None:  # Single-process loading
            worker_id = 0
            num_workers = 1
        else:  # Multi-process loading
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        # Basic worker sharding: each worker processes every Nth record.
        # More sophisticated sharding might involve byte offsets or file splitting.
        sharded_iterator = (
            record
            for i, record in enumerate(record_iterator)
            if i % num_workers == worker_id
        )

        # Apply processing within the worker's iterator chain
        processed_iterator = map(self.processor_fn, sharded_iterator)
        return processed_iterator

# Example Usage:
# processor = lambda line: torch.tensor([float(x) for x in line.strip().split()])
# dataset = ShardedLargeFileDataset('massive_dataset.txt', processor)
# loader = DataLoader(dataset, batch_size=64, num_workers=4)
#
# for batch in loader:
#     # Training step...
#     pass
In this refined example, get_worker_info() provides the id of the current worker and the total num_workers. The code then filters the base record_iterator so that worker k only processes records where index % num_workers == k. This ensures each worker gets a unique, interleaved subset of the data stream. Note that more complex sharding (e.g., assigning entire files or byte ranges to workers) might be necessary depending on the data format and storage.
Shuffling IterableDataset instances requires different strategies than map-style datasets. Since there's no global index, you cannot simply shuffle indices. Common approaches include:

- Shuffle buffers: maintain a fixed-size buffer of incoming samples and yield items from it in random order (a minimal sketch follows this list). DataLoader doesn't provide this out-of-the-box for IterableDataset, but libraries like torchdata (part of the PyTorch domain libraries ecosystem) offer DataPipes with shuffling capabilities (e.g., shuffle, sharding_filter).
- Source-level shuffling: randomize the order of the files or shards read by the IterableDataset in each epoch.
- Hybrid chunking: use an IterableDataset to stream chunks of data (e.g., file paths or record identifiers) and a map-style Dataset within each worker to load and process items from that chunk, allowing for shuffling within chunks.

The choice depends on the scale of the data, the required level of randomness, and the overhead you can tolerate.
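To illustrate the buffer idea without extra dependencies, here is a minimal sketch of a shuffle-buffer wrapper around any iterable source. The class name, default buffer size, and seed handling are illustrative, not a library API; the buffer size trades memory for randomness, since a larger buffer mixes samples from a wider window of the stream.

import random
from torch.utils.data import IterableDataset

class ShuffleBufferDataset(IterableDataset):
    def __init__(self, source_dataset, buffer_size=10_000, seed=None):
        super().__init__()
        self.source_dataset = source_dataset
        self.buffer_size = buffer_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        buffer = []
        for sample in self.source_dataset:
            if len(buffer) < self.buffer_size:
                # Fill the buffer before yielding anything
                buffer.append(sample)
            else:
                # Yield a random buffered sample and replace it with the new one
                idx = rng.randrange(self.buffer_size)
                yield buffer[idx]
                buffer[idx] = sample
        # Drain whatever remains once the stream is exhausted
        rng.shuffle(buffer)
        yield from buffer

# Example: wrap the sharded streaming dataset from earlier
# shuffled = ShuffleBufferDataset(ShardedLargeFileDataset('massive_dataset.txt', processor),
#                                 buffer_size=50_000)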
Regardless of whether you use map-style or iterable datasets, optimizing the data loading pipeline is significant for training performance, especially with large datasets where I/O can be a bottleneck.

- Efficient storage formats: libraries such as webdataset are designed for efficient streaming of large datasets stored as tar archives, often used with IterableDataset.
- DataLoader parameters (a configuration sketch follows this list):
  - num_workers: setting num_workers > 0 enables multiprocessing for data loading. The optimal value depends on the CPU cores, batch size, data processing complexity, and I/O speed. A common starting point is the number of CPU cores available, but experimentation is needed. Too few workers will bottleneck on data loading; too many can cause overhead or thrash system resources.
  - pin_memory=True: if loading data onto the GPU, setting this to True tells the DataLoader to put fetched tensors into pinned (page-locked) memory. This enables faster asynchronous data transfer from CPU to GPU using tensor.to('cuda', non_blocking=True).
  - prefetch_factor (PyTorch 1.7+): controls how many batches are prefetched by each worker. The default value (2) is often sufficient, but increasing it might help hide data loading latency if workers are sometimes slow.
- On-the-fly augmentation: perform augmentation inside the DataLoader workers (within the Dataset's processing logic or via collate_fn). This avoids storing multiple augmented versions of the data, saving disk space, which is especially important for large datasets.
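The following is a hedged configuration sketch tying these parameters together, reusing the ShardedLargeFileDataset from earlier; the file name and numeric values are illustrative starting points, not recommendations.

import torch
from torch.utils.data import DataLoader

# Reuse the streaming dataset defined above (hypothetical file name)
processor = lambda line: torch.tensor([float(x) for x in line.strip().split()])
dataset = ShardedLargeFileDataset('massive_dataset.txt', processor)

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,      # tune to CPU cores, I/O speed, and processing cost
    pin_memory=True,    # page-locked host memory enables async CPU-to-GPU copies
    prefetch_factor=2,  # batches prefetched per worker (the default)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for batch in loader:
    # non_blocking=True overlaps the transfer with computation when pin_memory=True
    batch = batch.to(device, non_blocking=True)
    # ... training step ...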
The following diagram illustrates how a DataLoader with multiple workers might handle an IterableDataset using sharding:
Data flow with an IterableDataset and two DataLoader workers. The dataset provides iterators to each worker, sharded to ensure unique data processing per worker, enabling parallel data loading for the training loop.
By employing IterableDataset, careful worker sharding, and an optimized data loading pipeline, you can effectively train PyTorch models on datasets that vastly exceed the capacity of your system's memory, overcoming a significant hurdle in large-scale deep learning. These techniques are often combined with others discussed in this chapter, such as gradient accumulation, to manage both data size and computational constraints.