趋近智
在使用数据并行策略(如MirroredStrategy或MultiWorkerMirroredStrategy)时,每个处理单元(副本,通常是GPU或工作机器)在每个训练步骤都会收到输入数据的一个不同片段。有效管理这种数据分发对于在分布式训练中获得良好性能和正确性非常重要。如果输入管道无法足够快地提供数据以使所有副本保持忙碌,昂贵的加速器就会闲置,从而抵消了分布式训练的好处。
tf.data 在分布式环境中的作用TensorFlow 的 tf.data API 是构建分布式训练输入管道的推荐方式。它为数据加载、预处理和迭代提供了高效、灵活的抽象。重要的是,tf.data.Dataset 对象与 tf.distribute.Strategy 结合。当你在策略范围内迭代数据集时,该策略会自动处理数据批次到不同副本的分发。
import tensorflow as tf
# 假设 'strategy' 是一个已初始化的 tf.distribute.Strategy
# 假设 'global_batch_size' 是所有副本的总批次大小
# 假设 'create_dataset()' 返回一个 tf.data.Dataset 实例
with strategy.scope():
# 通常为了更好的性能,在 tf.function *之外* 创建数据集
dataset = create_dataset()
# 分发数据集。每个副本获得全局批次的一部分。
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
# 模型定义和优化器创建会放在这里...
# ...
# 在你的训练循环或 tf.function 内部:
@tf.function
def distributed_train_step(data_batch):
# 'data_batch' 会在副本之间自动分片。
# 每个副本接收其全局批次的一部分。
def replica_fn(inputs):
# 模型前向传播、损失计算、梯度...
# ...
pass # 用实际的副本训练逻辑替换
# 在每个副本上运行计算
strategy.run(replica_fn, args=(data_batch,))
# 迭代分布式数据集
for batch in distributed_dataset:
distributed_train_step(batch)
# ... 训练循环的其余部分
在这个结构中,strategy.experimental_distribute_dataset 接受完整数据集并返回一个 DistributedDataset 对象。当你迭代 distributed_dataset 时,它会产生为参与策略的每个副本适当分割的批次。
默认情况下,当你将 tf.data.Dataset 与 tf.distribute.Strategy 一起使用时,TensorFlow 会尝试自动将数据集分片到参与的工作进程或副本上。分片意味着将数据集分割,以便每个工作进程只处理总数据的一小部分,从而防止重复工作并确保每个周期所有数据大约只被处理一次(取决于数据集大小和配置)。
默认策略 tf.data.experimental.AutoShardPolicy.AUTO 通常会尝试按文件分片,如果数据集源自文件(例如通过 tf.data.TFRecordDataset 读取的TFRecords)。如果基于文件的分片不可行,它会退回到按数据分片,即每个工作进程读取完整数据集,但动态跳过元素以只处理其分配的分片。虽然方便,但按数据分片可能由于重复读取而效率低下。在可能的情况下,通常更推荐基于文件的分片。
自动文件分片的示意图,其中包含四个文件的数据集分布到四个工作进程。每个工作进程处理一个文件。
你可以明确控制分片策略:
options = tf.data.Options()
# 明确选择按文件 (FILE) 分片、按数据 (DATA) 分片或关闭 (OFF)
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.FILE
dataset = dataset.with_options(options)
# 然后像之前一样分发
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
将策略设置为 tf.data.experimental.AutoShardPolicy.OFF 会禁用自动分片,如果需要,这要求你实现手动分片。
有时自动分片不足或不适用。在以下情况下,你可能需要手动控制:
tf.data 分片(例如,自定义数据生成器,或从不支持高效偏移/限制的数据库读取)。tf.data.Dataset。在这些情况下,你可以根据工作进程的上下文 (context)手动分片数据。tf.distribute.InputContext 对象提供有关当前工作进程的信息,包括其ID和工作进程总数。
# 示例:手动分片文件名列表
import os
# 假设 'strategy' 已初始化(例如,MultiWorkerMirroredStrategy)
# 假设 'all_filenames' 是所有数据文件的列表
input_options = tf.distribute.InputOptions(
experimental_fetch_to_device=False,
experimental_replication_mode=tf.distribute.InputReplicationMode.PER_WORKER
)
def dataset_fn(input_context):
# 从上下文中获取工作进程信息
worker_id = input_context.input_pipeline_id
num_workers = input_context.num_input_pipelines
# 简单分片:每个工作进程根据其ID获取文件
worker_filenames = [
f for i, f in enumerate(all_filenames)
if i % num_workers == worker_id
]
# 仅为该工作进程从分配的文件创建数据集
worker_dataset = tf.data.TFRecordDataset(worker_filenames)
# 应用特定于此工作进程数据集的预处理、批处理等
# 确保这里的 batch_size 是每个副本的批次大小
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
worker_dataset = worker_dataset.batch(per_replica_batch_size)
worker_dataset = worker_dataset.prefetch(tf.data.AUTOTUNE)
return worker_dataset
# 使用输入函数和选项创建分布式数据集
distributed_dataset = strategy.distribute_datasets_from_function(
dataset_fn,
input_options
)
# 像之前一样迭代和训练...
这里,distribute_datasets_from_function 会为每个工作进程调用 dataset_fn 一次,并传入一个 InputContext。该函数使用上下文来决定该特定工作进程应处理的数据文件子集。请注意 InputOptions 中指定的 PER_WORKER 复制模式,这表明 dataset_fn 定义了整个工作进程的数据集。
低效的输入管道往往是分布式训练中的主要瓶颈。如果数据加载和预处理无法跟上多个加速器的计算速度,训练时间将不会与添加的设备数量成比例地减少。请大力使用以下 tf.data 优化手段:
.prefetch(tf.data.AUTOTUNE)): 务必将其作为数据集管道的最后一步。它允许CPU在加速器忙于处理当前批次时准备下一个批次的数据,从而重叠数据准备和模型执行。.map(..., num_parallel_calls=tf.data.AUTOTUNE)): 如果你的预处理函数(map)是CPU密集型的,请并行运行多个调用以使用多个CPU核心。tf.data.AUTOTUNE 让 TensorFlow 动态调整并行级别。.cache()): 如果你的整个数据集可以放入内存(或者如果你向 .cache() 提供文件名,可以放入本地磁盘),缓存会在第一个周期之后转换数据集。随后的周期直接从缓存读取,这可能会显著加快加载速度,特别是当预处理开销大或源数据是远程的时。在 CPU 密集型预处理之后但在每个周期都应发生的操作(如混洗或批处理)之前使用它。.interleave(..., num_parallel_calls=tf.data.AUTOTUNE)): 当从多个文件(例如 TFRecord 分片)读取时,交错会并发地从多个文件读取数据块,与顺序读取文件相比,这可以提高吞吐量 (throughput),尤其是在远程存储的情况下。使用 TensorFlow Profiler 监控你的输入管道性能,以识别并解决瓶颈。确保训练期间的CPU利用率高,这表明管道工作高效。
当处理存储在多个文件中的数据集(例如 TFRecords)时,请考虑以下做法:
tf.data.Dataset.list_files(..., shuffle=True),请注意这混洗的是文件列表。为了在各个周期之间进行有效混洗,特别是在基于文件的分片中,通常更好的做法是列出文件,然后使用 reshuffle_each_iteration=True 混洗生成的文件名数据集,接着交错读取这些混洗过的文件。# 示例:分片 TFRecords 的推荐模式
num_workers = strategy.num_replicas_in_sync # 为单工作进程多GPU简化
worker_id = 0 # 为单工作进程多GPU简化
if hasattr(strategy.extended, '_input_workers'): # 检查多工作进程情况
num_workers = strategy.extended._input_workers.num_workers
worker_id = strategy.extended._input_workers.worker_index
file_pattern = "/path/to/tfrecords/train-*.tfrecord"
per_replica_batch_size = 64
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
files = tf.data.Dataset.list_files(file_pattern, shuffle=False) # 确定性地列出文件
# 手动将文件列表分片到工作进程
files = files.shard(num_workers, worker_id)
# 每个周期在工作进程的分片内混洗文件
files = files.shuffle(buffer_size=tf.data.AUTOTUNE, reshuffle_each_iteration=True)
# 并发地交错读取多个文件
dataset = files.interleave(
lambda filepath: tf.data.TFRecordDataset(filepath),
cycle_length=tf.data.AUTOTUNE,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False # 允许非确定性以提高性能
)
# 进一步混洗记录、映射处理、批处理和预取
dataset = dataset.shuffle(buffer_size=10000) # 混洗记录
dataset = dataset.map(decode_and_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True) # 使用每个副本的批次大小
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# 分发每个工作进程构建的数据集(如果使用 distribute_datasets_from_function)
# 或者如果全局构建并使用 experimental_distribute_dataset,则让自动分片处理
# distributed_dataset = strategy.experimental_distribute_dataset(dataset) # 如果是全局构建的
此示例演示了手动文件分片,并结合了每个工作进程混洗文件和交错读取等最佳实践。
不正确的数据处理可能导致隐蔽的错误或模型性能下降:
MirroredStrategy、MultiWorkerMirroredStrategy)中,所有副本每步必须处理相同数量的样本,以使梯度聚合正常工作。在你的数据集上调用 .batch() 时使用 drop_remainder=True。这确保了如果存在最后一个部分批次,它会被丢弃,从而在所有步骤和副本中保持批次大小一致。虽然这意味着丢弃少量数据,但这通常是同步分布式训练正确性所必需的。tf.data 管道中(例如,自定义计数器),请谨慎使用有状态操作,因为在分布式环境中状态管理会变得复杂。在可能的情况下,优先选择无状态转换。通过精心构建和优化你的 tf.data 输入管道,兼顾自动和手动分片技术,并应用预取和并行处理等性能最佳实践,你可以确保在扩展 TensorFlow 训练任务时,数据加载不会成为阻碍。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造