When employing data parallelism strategies like MirroredStrategy or MultiWorkerMirroredStrategy, each processing unit (replica, typically a GPU or a worker machine) receives a distinct slice of the input data for each training step. Managing this data distribution effectively is fundamental to achieving good performance and correctness in distributed training. If the input pipeline cannot supply data fast enough to keep all replicas busy, the expensive accelerators will sit idle, negating the benefits of distribution.

tf.data in Distributed Settings

TensorFlow's tf.data API is the recommended way to build input pipelines for distributed training. It provides efficient, flexible abstractions for data loading, preprocessing, and iteration. Crucially, tf.data.Dataset objects integrate seamlessly with tf.distribute.Strategy. When you iterate over a dataset within a strategy's scope, the strategy automatically handles the distribution of data batches to the different replicas.
import tensorflow as tf

# Assume 'strategy' is an initialized tf.distribute.Strategy
# Assume 'global_batch_size' is the total batch size across all replicas
# Assume 'create_dataset()' returns a tf.data.Dataset instance

with strategy.scope():
    # Create the dataset *outside* tf.function for better performance generally
    dataset = create_dataset()
    # Distribute the dataset. Each replica gets a portion of the global batch.
    distributed_dataset = strategy.experimental_distribute_dataset(dataset)
    # Model definition and optimizer creation would go here...
    # ...

# Inside your training loop or tf.function:
@tf.function
def distributed_train_step(data_batch):
    # 'data_batch' is automatically sharded across replicas.
    # Each replica receives its portion of the global batch.
    def replica_fn(inputs):
        # Model forward pass, loss calculation, gradients...
        # ...
        pass  # Replace with actual replica training logic

    # Run the computation on each replica
    strategy.run(replica_fn, args=(data_batch,))

# Iterate over the distributed dataset
for batch in distributed_dataset:
    distributed_train_step(batch)
    # ... rest of the training loop
In this structure, strategy.experimental_distribute_dataset takes the full dataset and returns a DistributedDataset object. When you iterate over distributed_dataset, it yields batches split appropriately for each replica participating in the strategy.
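To see this splitting concretely, you can pull a single element from the distributed dataset and unpack its per-replica components with strategy.experimental_local_results. This is a minimal sketch that assumes strategy and distributed_dataset are defined as above and that the dataset yields plain tensors rather than nested structures.

# Sketch: inspecting how one distributed batch is split across replicas.
one_batch = next(iter(distributed_dataset))

# Unpack the local per-replica components of this batch.
per_replica_tensors = strategy.experimental_local_results(one_batch)
for replica_index, tensor in enumerate(per_replica_tensors):
    # Each replica receives global_batch_size / num_replicas_in_sync examples.
    print(f"Replica {replica_index} batch shape: {tensor.shape}")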
By default, when you use tf.data.Dataset with a tf.distribute.Strategy, TensorFlow attempts to automatically shard the dataset across the participating workers. Sharding means dividing the dataset so that each worker processes only a fraction of the total data, preventing redundant work and ensuring all data is seen approximately once per epoch (depending on the dataset size and configuration).

The default policy, tf.data.experimental.AutoShardPolicy.AUTO, usually tries to shard by file if the dataset originates from file sources (like TFRecords read via tf.data.TFRecordDataset). If file-based sharding isn't feasible, it falls back to sharding by data, where each worker reads the full dataset but dynamically skips elements to process only its assigned shard. While convenient, sharding by data can be inefficient due to redundant reads. File-based sharding is generally preferred when possible.
Diagram: automatic file-based sharding, where a dataset composed of four files is distributed across four workers; each worker processes one file.
You can explicitly control the sharding policy:
options = tf.data.Options()
# Explicitly choose FILE sharding, or DATA, OFF
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.FILE
dataset = dataset.with_options(options)
# Then distribute as before
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
Setting the policy to tf.data.experimental.AutoShardPolicy.OFF disables automatic sharding, requiring you to implement manual sharding if needed.
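When auto-sharding is turned off in a multi-worker job, each worker reads the full dataset unless you split it yourself. Below is a minimal sketch of pairing OFF with an explicit Dataset.shard call; the file path is a placeholder, and num_workers and worker_index are assumed to be determined by your own cluster configuration (for example, parsed from TF_CONFIG).

# Sketch: auto-sharding disabled, sharding handled explicitly per worker.
# 'num_workers' and 'worker_index' are assumed to be known to each worker.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF
)

filenames = tf.io.gfile.glob("/path/to/tfrecords/train-*.tfrecord")  # Placeholder path
dataset = tf.data.TFRecordDataset(filenames)
# Keep every num_workers-th record, starting at this worker's index.
dataset = dataset.shard(num_shards=num_workers, index=worker_index)
dataset = dataset.with_options(options)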
Sometimes automatic sharding isn't sufficient or applicable. You might need manual control if:

- Your data source is not naturally expressed through tf.data (e.g., a custom data generator, or reading from a database that doesn't support offset/limit efficiently).
- You need finer control over which portion of the data each worker receives than automatic sharding of a tf.data.Dataset provides.

In these cases, you can manually shard the data based on the worker's context. The tf.distribute.InputContext object provides information about the current worker, including its ID and the total number of workers.
# Example: Manually sharding a list of filenames
import tensorflow as tf

# Assume 'strategy' is initialized (e.g., MultiWorkerMirroredStrategy)
# Assume 'all_filenames' is a list of all data files
# Assume 'global_batch_size' is the total batch size across all replicas

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_WORKER
)

def dataset_fn(input_context):
    # Get worker info from the context
    worker_id = input_context.input_pipeline_id
    num_workers = input_context.num_input_pipelines

    # Simple sharding: each worker gets files based on its ID
    worker_filenames = [
        f for i, f in enumerate(all_filenames)
        if i % num_workers == worker_id
    ]

    # Create a dataset only from the assigned files for this worker
    worker_dataset = tf.data.TFRecordDataset(worker_filenames)

    # Apply preprocessing, batching etc. specific to this worker's dataset.
    # Make sure batch_size here is the PER-REPLICA batch size.
    per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
    worker_dataset = worker_dataset.batch(per_replica_batch_size)
    worker_dataset = worker_dataset.prefetch(tf.data.AUTOTUNE)
    return worker_dataset

# Create the distributed dataset using the input function and options
distributed_dataset = strategy.distribute_datasets_from_function(
    dataset_fn,
    input_options
)
# Iterate and train as before...
Here, distribute_datasets_from_function calls dataset_fn once for each worker, passing an InputContext. The function uses the context to determine which subset of data files that specific worker should handle. Note the InputOptions specifying PER_WORKER replication mode, indicating that dataset_fn defines the dataset for an entire worker.
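The InputContext can also derive the per-replica batch size for you, so dataset_fn does not need to reference the strategy directly. A small sketch of that variant, reusing the input_options, all_filenames, and global_batch_size from above (the name dataset_fn_v2 is only for illustration):

# Sketch: letting the InputContext compute the per-replica batch size.
def dataset_fn_v2(input_context):
    worker_filenames = [
        f for i, f in enumerate(all_filenames)
        if i % input_context.num_input_pipelines == input_context.input_pipeline_id
    ]
    dataset = tf.data.TFRecordDataset(worker_filenames)
    # Equivalent to global_batch_size // strategy.num_replicas_in_sync.
    per_replica_batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    return dataset.batch(per_replica_batch_size).prefetch(tf.data.AUTOTUNE)

distributed_dataset = strategy.distribute_datasets_from_function(
    dataset_fn_v2,
    input_options
)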
An inefficient input pipeline is often a primary bottleneck in distributed training. If data loading and preprocessing cannot keep pace with the computation speed of multiple accelerators, training time will not decrease proportionally to the number of devices added. Use the following tf.data optimizations aggressively:
- Prefetching (.prefetch(tf.data.AUTOTUNE)): Always add this as the last step in your dataset pipeline. It allows the CPU to prepare the next batch(es) of data while the accelerators are busy processing the current batch, overlapping data preparation and model execution.
- Parallel mapping (.map(..., num_parallel_calls=tf.data.AUTOTUNE)): If your preprocessing function (map) is CPU-intensive, run multiple calls in parallel to utilize multiple CPU cores. tf.data.AUTOTUNE lets TensorFlow dynamically tune the level of parallelism.
- Caching (.cache()): If your entire dataset fits into memory (or local disk, if you provide a filename to .cache()), caching stores the preprocessed elements after the first epoch. Subsequent epochs read directly from the cache, potentially speeding up loading significantly, especially if preprocessing is expensive or the source data is remote. Use it after CPU-intensive preprocessing but before operations like shuffling or batching that should occur each epoch.
- Parallel interleaving (.interleave(..., num_parallel_calls=tf.data.AUTOTUNE)): When reading from multiple files (e.g., TFRecord shards), interleaving reads blocks from multiple files concurrently, improving throughput compared to reading files sequentially, especially with remote storage.

Monitor your input pipeline performance using the TensorFlow Profiler to identify and address bottlenecks. Ensure your CPU utilization is high during training, indicating the pipeline is working effectively. A compact pipeline combining these optimizations is sketched below.
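The sketch below strings these optimizations together in the recommended order, using a small synthetic dataset so it runs as-is; preprocess_example stands in for your real decoding and augmentation logic.

# Sketch: combining parallel map, cache, shuffle, batch, and prefetch.
import tensorflow as tf

def preprocess_example(x):
    # Stand-in for CPU-heavy decoding/augmentation.
    return tf.cast(x, tf.float32) / 255.0

per_replica_batch_size = 64  # Placeholder value

dataset = tf.data.Dataset.range(100_000)           # Replace with your real source
dataset = dataset.map(preprocess_example,
                      num_parallel_calls=tf.data.AUTOTUNE)  # Parallel preprocessing
dataset = dataset.cache()                          # Only if the data fits in memory/disk
dataset = dataset.shuffle(buffer_size=10_000)      # After cache, so order differs each epoch
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)       # Always the final step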
When working with datasets stored across many files (e.g., TFRecords), consider these practices:

- When using tf.data.Dataset.list_files(..., shuffle=True), be aware that this shuffles the list of files. For effective shuffling across epochs, especially with file-based sharding, it's often better to list files, shuffle the resulting dataset of filenames with reshuffle_each_iteration=True, and then interleave reads from these shuffled filenames.

# Example: Recommended pattern for sharded TFRecords
# Determine how many input pipelines exist and which one this process runs.
# The defaults cover the single-worker (multi-GPU) case.
num_workers = 1
worker_id = 0
resolver = getattr(strategy, 'cluster_resolver', None)  # None for single-worker strategies
if resolver is not None and resolver.cluster_spec().as_dict():
    num_workers = resolver.cluster_spec().num_tasks('worker')
    worker_id = resolver.task_id  # Index of this task within the 'worker' job

file_pattern = "/path/to/tfrecords/train-*.tfrecord"
per_replica_batch_size = 64
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

# List files deterministically so every worker sees the same ordering
files = tf.data.Dataset.list_files(file_pattern, shuffle=False)

# Manually shard the list of files across workers
files = files.shard(num_workers, worker_id)

# Shuffle files within the worker's shard each epoch
# (buffer_size should be at least the number of files in the shard)
files = files.shuffle(buffer_size=1000, reshuffle_each_iteration=True)

# Interleave reads from multiple files concurrently
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=tf.data.AUTOTUNE,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False  # Allow non-determinism for performance
)

# Shuffle individual records, then map, batch, and prefetch.
# 'decode_and_preprocess_fn' is assumed to parse and transform a serialized example.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.map(decode_and_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)  # Per-replica batch size
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# This per-worker dataset is intended for strategy.distribute_datasets_from_function.
# If you instead build the dataset once globally and use
# strategy.experimental_distribute_dataset, skip the manual files.shard(...) call,
# batch with global_batch_size, and let auto-sharding divide the files.
# distributed_dataset = strategy.experimental_distribute_dataset(dataset)  # Global variant
This example demonstrates manual file sharding combined with best practices like shuffling files per worker and interleaving reads.
Incorrect data handling can lead to subtle bugs or degraded model performance:

- Partial batches: With synchronous strategies (MirroredStrategy, MultiWorkerMirroredStrategy), all replicas must process the same number of examples per step for gradient aggregation to work correctly. Use drop_remainder=True when calling .batch() on your dataset. This ensures that the last partial batch, if any, is dropped, maintaining consistent batch sizes across all steps and replicas. While it means discarding a small amount of data, it's usually necessary for synchronous distributed training correctness (see the sketch after this list).
- Stateful pipelines: Avoid relying on stateful operations inside your tf.data pipeline (e.g., custom counters), as state management can become complex in distributed settings. Prefer stateless transformations where possible.
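As a quick illustration of the drop_remainder behavior described in the first point above, the sketch below batches a small synthetic dataset with and without the flag; the numbers are arbitrary and only show the shape difference.

# Sketch: why drop_remainder=True matters for synchronous training.
import tensorflow as tf

examples = tf.data.Dataset.range(10)  # 10 examples, batch size 4 -> partial batch of 2

with_partial = examples.batch(4)                          # Batch shapes: [4], [4], [2]
without_partial = examples.batch(4, drop_remainder=True)  # Batch shapes: [4], [4]

print([batch.shape.as_list() for batch in with_partial])
print([batch.shape.as_list() for batch in without_partial])
# With drop_remainder=True every step sees the same batch size, which the
# synchronous gradient aggregation across replicas relies on.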
By carefully constructing and optimizing your tf.data input pipelines, considering both automatic and manual sharding techniques, and applying performance best practices like prefetching and parallel processing, you can ensure that data loading does not become a hindrance when scaling your TensorFlow training jobs.