趋近智
当数据集大到无法在系统内存(RAM)中轻松容纳时,使用 PyTorch 映射式 Dataset 的标准数据加载方法会成为瓶颈或完全不可行。映射式数据集实现了 __getitem__ 和 __len__,通常假设可以按索引随机访问任何项目,并且可能需要预先加载整个数据集的元数据。这里说明了专门为处理这些巨型数据集而设计的策略,侧重于使用 IterableDataset 高效地传输数据。
想象一个存储在数千个文件中的 TB 级图像数据集。一个标准 Dataset 可能会尝试在其 __init__ 方法中构建所有文件路径和对应标签的列表。即使图像本身没有加载,仅这些元数据就可能超出可用内存。此外,如果数据需要从大型压缩文件或数据库查询中按顺序读取,__getitem__ 的随机访问要求可能效率低下。打乱大型映射式数据集通常也涉及创建一个覆盖整个数据集大小 (N) 的打乱索引列表,这对于大型 N 来说同样需要大量内存。
PyTorch 提供了一种替代方案:torch.utils.data.IterableDataset。您无需定义 __getitem__ 和 __len__,而是实现 __iter__ 方法。此方法应返回一个迭代器,每次生成一个样本。此方法根本不同;它将数据集视为数据流,而非可索引的集合。
IterableDataset 特别适合以下情况:
这是一个从大文件中逐行读取样本的实现方式:
import torch
from torch.utils.data import IterableDataset, DataLoader
class LargeTextFileDataset(IterableDataset):
def __init__(self, file_path, tokenizer):
super().__init__()
self.file_path = file_path
self.tokenizer = tokenizer
def __iter__(self):
# 迭代器在此处为每个 epoch/worker 创建
file_iterator = open(self.file_path, 'r')
# map 函数将处理函数应用于迭代器中的每一行
return map(self.tokenizer, file_iterator)
# 用法:
# tokenizer = lambda line: torch.tensor([int(x) for x in line.strip().split(',')])
# dataset = LargeTextFileDataset('very_large_data.csv', tokenizer)
# loader = DataLoader(dataset, batch_size=32)
#
# for batch in loader:
# # 处理批次数据
# pass
在此示例中,open(self.file_path, 'r') 返回一个遍历文件行的迭代器。然后 map 函数在 DataLoader 请求时,对每一行进行延迟处理(应用 tokenizer)。没有尝试将整个文件加载到内存中。
当使用 DataLoader 且 num_workers > 0 时,每个工作进程会获得 IterableDataset 实例的一个副本。一个重要方面是,需要确保每个工作进程处理数据流中不同的部分,以避免重复。如果处理不当,每个工作进程都可能从头开始读取同一个大文件,导致重复工作和不正确的有效批次构成。
解决此问题的标准方法是在 __iter__ 方法中使用 torch.utils.data.get_worker_info():
import torch
import math
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
class ShardedLargeFileDataset(IterableDataset):
def __init__(self, file_path, processor_fn):
super().__init__()
self.file_path = file_path
self.processor_fn = processor_fn
# 如果需要分片,确定文件大小或行/记录数量
# self.num_records = self._get_num_records(file_path) # 辅助函数示例
def _get_records_iterator(self):
# 将此替换为遍历您的特定数据记录/文件的逻辑
with open(self.file_path, 'r') as f:
for line in f:
yield line # 生成原始记录
def __iter__(self):
worker_info = get_worker_info()
record_iterator = self._get_records_iterator()
if worker_info is None: # 单进程加载
worker_id = 0
num_workers = 1
else: # 多进程加载
worker_id = worker_info.id
num_workers = worker_info.num_workers
# 基础工作进程分片:每个工作进程处理每第 N 条记录
# 更复杂的分片可能涉及字节偏移量或文件拆分
sharded_iterator = (record for i, record in enumerate(record_iterator) if i % num_workers == worker_id)
# 在工作进程的迭代器链中应用处理
processed_iterator = map(self.processor_fn, sharded_iterator)
return processed_iterator
# 使用示例:
# processor = lambda line: torch.tensor([float(x) for x in line.strip().split()])
# dataset = ShardedLargeFileDataset('massive_dataset.txt', processor)
# loader = DataLoader(dataset, batch_size=64, num_workers=4)
#
# for batch in loader:
# # 训练步骤...
# pass
在此改进示例中,get_worker_info() 提供当前工作进程的 id 和总 num_workers 数量。然后代码过滤基础 record_iterator,使得工作进程 k 只处理 index % num_workers == k 的记录。这确保每个工作进程获得数据流中独有的、交错的子集。请注意,根据数据格式和存储方式,可能需要更复杂的分片(例如,将整个文件或字节范围分配给工作进程)。
打乱 IterableDataset 实例需要不同的策略,不同于映射式数据集。由于没有全局索引,您不能简单地打乱索引。常见方法包括:
DataLoader 不直接提供 IterableDataset 的此功能,但像 torchdata(PyTorch 领域库生态系统的一部分)这样的库提供了具有打乱功能的 DataPipes(例如,shuffle、sharding_filter)。IterableDataset 处理的文件列表。IterableDataset 流式传输数据块(例如文件路径或记录标识符),并在每个工作进程内使用映射式 Dataset 从该数据块加载和处理项目,从而允许在数据块内部进行打乱。选择取决于数据规模、所需的随机性程度以及您可以承受的开销。
无论您使用映射式还是可迭代数据集,优化数据加载管道对训练性能非常重要,特别是对于 I/O 可能成为瓶颈的大型数据集。
webdataset 这样的库旨在高效流式传输存储为 tar 归档文件的大型数据集,通常与 IterableDataset 配合使用。DataLoader 参数:
num_workers:将 num_workers 设置为 > 0 可启用数据加载的多进程处理。最优值取决于 CPU 核心数、批次大小、数据处理复杂度和 I/O 速度。一个常见的起始点是可用 CPU 核心数,但需要进行实验。工作进程过少会导致数据加载瓶颈;过多则会引起开销或耗尽系统资源。pin_memory=True:如果将数据加载到 GPU 上,将其设置为 True 会告诉 DataLoader 将获取的张量放入固定(页锁定)内存。这使得使用 tensor.to('cuda', non_blocking=True) 从 CPU 到 GPU 的异步数据传输更快。prefetch_factor (PyTorch 1.7+):控制每个工作进程预取多少批次。默认值 (2) 通常足够,但如果工作进程有时速度较慢,增加此值可能有助于隐藏数据加载延迟。以下图表说明了带有多个工作进程的 DataLoader 如何使用分片处理 IterableDataset:
IterableDataset和两个DataLoader工作进程的数据流。数据集向每个工作进程提供迭代器,并通过分片确保每个工作进程处理独特的数据,从而实现训练循环的并行数据加载。
通过使用 IterableDataset、细致的工作进程分片和优化数据加载管道,可以有效训练 PyTorch 模型,即便数据集远超系统内存容量,克服了大规模深度学习的一个重要障碍。这些技术通常与本章讨论的其他方法结合使用,例如梯度累积,以管理数据大小和计算限制。
这部分内容有帮助吗?
DataLoader配置以及多worker处理。torchdata库,该库扩展了PyTorch的数据加载能力,为可迭代数据集提供了DataPipes以实现高级流式传输和混洗。© 2026 ApX Machine Learning用心打造