Let's put the concepts from this chapter into practice. We'll walk through setting up and running a distributed training job using TensorFlow's tf.distribute.Strategy. This hands-on exercise will solidify your understanding of how to scale training across multiple devices. For simplicity and broad applicability, we will focus primarily on tf.distribute.MirroredStrategy, which handles synchronous training on multiple GPUs within a single machine. We will also briefly touch upon the setup required for multi-worker training.
Ensure you have TensorFlow installed. If you have multiple GPUs available on your system, TensorFlow should detect them automatically. You can verify this:
import tensorflow as tf
import os    # used later for the TF_CONFIG environment variable
import json  # used later to build TF_CONFIG as a JSON string

print("TensorFlow version:", tf.__version__)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"Detected {len(gpus)} GPU(s):")
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
else:
    print("No GPU detected. MirroredStrategy will fall back to the CPU,")
    print("so there will be no speedup from distribution.")
    print("Consider using tf.distribute.OneDeviceStrategy('/cpu:0') for CPU practice.")

# Define a buffer size for shuffling the dataset
BUFFER_SIZE = 10000

# Define the global batch size. This will be split across replicas.
GLOBAL_BATCH_SIZE = 64 * len(gpus) if gpus else 64  # Example: 64 per replica

# Define the number of epochs for training
EPOCHS = 5
If you don't have multiple GPUs, MirroredStrategy will still run on a single GPU or even the CPU (though without performance gains from distribution). The principles remain the same.
Distributed training requires careful handling of the dataset. Each replica (processing unit, typically a GPU) needs to process a distinct shard of the input data. The tf.data API integrates seamlessly with tf.distribute.Strategy. When you pass a tf.data.Dataset to model.fit within a strategy scope, TensorFlow automatically handles sharding the data across replicas. For a custom training loop, you distribute the dataset explicitly with strategy.experimental_distribute_dataset, as we will do later in this section.

Let's create a simple synthetic dataset using tf.data:
# Create a simple synthetic dataset
def create_synthetic_dataset(num_samples=10000, num_features=10):
    # Generate random features and binary labels
    X = tf.random.normal(shape=(num_samples, num_features))
    coeffs = tf.random.normal(shape=(num_features, 1))
    logits = tf.matmul(X, coeffs) + tf.random.normal(shape=(num_samples, 1), stddev=0.1)
    y = tf.cast(logits > 0, tf.int64)
    return tf.data.Dataset.from_tensor_slices((X, y))

# Create the dataset
dataset = create_synthetic_dataset()

# Shuffle and batch the dataset
# Note: Batch size here is the GLOBAL batch size
dataset = dataset.shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Dataset element spec: {dataset.element_spec}")
print(f"Global batch size: {GLOBAL_BATCH_SIZE}")
if gpus:
    print(f"Per-replica batch size: {GLOBAL_BATCH_SIZE // len(gpus)}")
The GLOBAL_BATCH_SIZE is the total batch size processed across all replicas in one step. MirroredStrategy will automatically divide this global batch size by the number of replicas to determine the per-replica batch size. prefetch(tf.data.AUTOTUNE) is used for performance, allowing data preprocessing to happen concurrently with model execution.
We'll use a standard Keras Sequential model for this example. What matters for distribution is not the architecture itself but where the model is created: inside the strategy scope, as we will see shortly.
def build_model(num_features=10):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(num_features,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model
Now, we choose our distribution strategy. For multiple GPUs on one machine, MirroredStrategy is the typical choice.
# If GPUs are available, use MirroredStrategy. Otherwise, use OneDeviceStrategy for CPU.
if gpus:
    strategy = tf.distribute.MirroredStrategy()
    print(f"Using MirroredStrategy with {strategy.num_replicas_in_sync} replicas.")
else:
    # Fallback for CPU or single GPU scenarios for demonstration purposes
    strategy = tf.distribute.OneDeviceStrategy('/cpu:0')  # Or '/gpu:0' if one GPU exists
    print("Using OneDeviceStrategy (fallback).")

# Calculate the effective number of replicas
num_replicas = strategy.num_replicas_in_sync
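If you want to see the batch splitting described above in action, the short optional check below distributes the dataset and prints the local batch shape each replica receives. It is a minimal sketch, assuming the dataset and strategy objects defined so far; strategy.experimental_local_results unpacks the per-replica components of a distributed value.

# Optional check: how does the global batch get split across replicas?
dist_check = strategy.experimental_distribute_dataset(dataset)
first_features, first_labels = next(iter(dist_check))
for i, local_batch in enumerate(strategy.experimental_local_results(first_features)):
    print(f"Replica {i} local batch shape: {local_batch.shape}")
# With N replicas, each one sees roughly GLOBAL_BATCH_SIZE // N examples per step.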
This is a critical step. To ensure that model variables and the optimizer's state are created on the appropriate devices and mirrored correctly, their creation must happen inside strategy.scope().
# Create the model and optimizer within the strategy scope
with strategy.scope():
    # Model building
    model = build_model()

    # Optimizer definition
    optimizer = tf.keras.optimizers.Adam()

    # Loss function definition
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.NONE  # Important for distributed training!
    )

    # Define metrics
    train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

    # Compiling is optional here: the custom training loop below computes the
    # loss and applies the gradients manually. The model.fit alternative later
    # in this section compiles with a loss that uses the default reduction.
Notice the reduction=tf.keras.losses.Reduction.NONE argument for the loss function. When writing a custom training loop with tf.distribute.Strategy, the loss should be computed per example within each replica's batch. The per-example losses are then scaled by the global batch size and summed across replicas, which yields the correct average over the full global batch. If you used the default SUM_OVER_BATCH_SIZE reduction, the loss would already be averaged over the per-replica batch, and aggregating across replicas would scale it incorrectly.
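To make this scaling concrete, here is a small optional sketch. The labels_demo and preds_demo values are made up purely for illustration; it shows that with Reduction.NONE the loss object returns one value per example, and that tf.nn.compute_average_loss divides the summed per-example losses by the global batch size.

# Illustration only: made-up labels and predictions for four examples.
labels_demo = tf.constant([[1.0], [0.0], [1.0], [0.0]])
preds_demo = tf.constant([[0.9], [0.2], [0.7], [0.4]])

# With Reduction.NONE, we get one loss value per example (shape: (4,)).
per_example = loss_object(labels_demo, preds_demo)
print("Per-example losses:", per_example.numpy())

# compute_average_loss sums the per-example losses and divides by the GLOBAL
# batch size, so summing the per-replica results reproduces the global mean.
scaled = tf.nn.compute_average_loss(per_example, global_batch_size=GLOBAL_BATCH_SIZE)
print("Scaled contribution to the global-batch loss:", float(scaled))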
We can use the standard model.fit API, which integrates seamlessly with tf.distribute.Strategy. Alternatively, let's illustrate how a custom training step looks within a distributed context using strategy.run and strategy.reduce.
# Function to compute loss per replica
def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    # tf.nn.compute_average_loss sums the per-example losses and divides by the
    # GLOBAL batch size, so summing the per-replica results later yields the
    # correct mean loss over the whole global batch.
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

# Define the training step function
@tf.function
def distributed_train_step(inputs):
    features, labels = inputs

    def step_fn(features, labels):
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            loss = compute_loss(labels, predictions)
        # Calculate gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        # Apply gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Update metrics
        train_accuracy.update_state(labels, predictions)
        return loss

    # Run the step function on each replica
    per_replica_losses = strategy.run(step_fn, args=(features, labels))
    # Sum the per-replica losses. Because compute_loss already divided by the
    # global batch size, this sum is the average loss over the global batch.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
Key points in distributed_train_step:

- The @tf.function decorator compiles the training step into a TensorFlow graph for performance.
- strategy.run(step_fn, ...) executes step_fn on each replica, providing the appropriate data slice. step_fn contains the core forward pass, loss calculation, and gradient computation/application for one replica.
- compute_loss calculates the per-example losses for the replica's local batch and scales them by the global batch size using tf.nn.compute_average_loss.
- strategy.reduce(...) aggregates values (like the loss) returned from each replica. Here, we sum the per-replica losses, which gives the average loss over the global batch because compute_loss already incorporated the global batch size scaling.

Now we distribute the dataset with strategy.experimental_distribute_dataset and iterate through the epochs and steps, calling our distributed training step.
print("Starting distributed training...")
for epoch in range(EPOCHS):
total_loss = 0.0
num_batches = 0
# Reset metrics at the start of each epoch
train_accuracy.reset_state()
# Iterate over the distributed dataset
for batch_inputs in dataset:
# Run the distributed training step
batch_loss = distributed_train_step(batch_inputs)
total_loss += batch_loss
num_batches += 1
# Optional: Print progress within the epoch
# if num_batches % 50 == 0:
# print(f" Epoch {epoch+1}, Batch {num_batches}, Loss: {batch_loss.numpy():.4f}, Accuracy: {train_accuracy.result().numpy():.4f}")
# Calculate average loss over the epoch
epoch_loss = total_loss / num_batches
epoch_accuracy = train_accuracy.result()
print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")
print("Distributed training finished.")
Using model.compile and model.fit

Alternatively, if you prefer the higher-level Keras API, you can use model.compile and model.fit. Ensure the model, optimizer, and metrics are created within strategy.scope(). model.fit will automatically handle the distribution when provided with a tf.data.Dataset.
# Re-create the model and optimizer within the strategy scope
with strategy.scope():
    model_fit = build_model()
    optimizer_fit = tf.keras.optimizers.Adam()

    # Use the standard (default) reduction for compile/fit unless specific needs arise
    loss_fit = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    model_fit.compile(optimizer=optimizer_fit, loss=loss_fit, metrics=['accuracy'])

print("\nStarting distributed training with model.fit...")

# model.fit handles dataset distribution automatically
history = model_fit.fit(dataset, epochs=EPOCHS, verbose=1)

print("Distributed training with model.fit finished.")
print("History:", history.history)
Using model.fit is often simpler, but understanding the custom loop with strategy.run provides deeper insight into how the distribution mechanism operates.
Multi-Worker Training (MultiWorkerMirroredStrategy)

While MirroredStrategy handles multiple devices on one machine, MultiWorkerMirroredStrategy scales across multiple machines (workers). Setting it up involves the following steps (a combined sketch follows the list):

1. TF_CONFIG Environment Variable: Each worker needs a TF_CONFIG environment variable set (usually as a JSON string). This variable tells the worker about the entire cluster setup (the addresses of all workers, plus the current worker's type and index).
# Example TF_CONFIG for worker 0 (on machine A)
{
  "cluster": {
    "worker": ["machine-a.example.com:20000", "machine-b.example.com:20000"]
  },
  "task": {"type": "worker", "index": 0}
}

# Example TF_CONFIG for worker 1 (on machine B)
{
  "cluster": {
    "worker": ["machine-a.example.com:20000", "machine-b.example.com:20000"]
  },
  "task": {"type": "worker", "index": 1}
}
2. Strategy Instantiation: You instantiate tf.distribute.MultiWorkerMirroredStrategy() instead of MirroredStrategy.

3. Code Execution: You run the same Python script on all worker machines. TensorFlow uses TF_CONFIG to coordinate.

4. Dataset Sharding: Ensure your tf.data pipeline correctly shards data across workers. Often this involves setting tf.data.experimental.AutoShardPolicy.DATA or tf.data.experimental.AutoShardPolicy.FILE in the dataset options, or manually sharding based on TF_CONFIG.

5. Checkpointing: Robust checkpointing (e.g., using tf.train.CheckpointManager) is essential for fault tolerance in multi-worker setups. Checkpoints should typically be saved to a shared filesystem accessible by all workers.
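The sketch below pulls these pieces together into a single script that would be run, unchanged except for the task index, on every worker. It is illustrative only: the hostnames, port, task index, and checkpoint directory are placeholder values, TF_CONFIG is normally set by your cluster scheduler rather than inside the training script, and the build_model and create_synthetic_dataset helpers are reused from earlier in this section.

# Illustrative multi-worker skeleton; hostnames, port, and paths are placeholders.
import json
import os
import tensorflow as tf

# 1. TF_CONFIG: shown inline for illustration only. In practice the cluster
#    scheduler sets it, and "index" differs on each worker (0, 1, ...).
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["machine-a.example.com:20000", "machine-b.example.com:20000"]
    },
    "task": {"type": "worker", "index": 0},
})

# 2. Instantiate the multi-worker strategy; it reads TF_CONFIG at construction.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# 4. Ask tf.data to shard by data: every worker reads the dataset but keeps
#    only its own slice of each global batch.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = (create_synthetic_dataset()
           .batch(GLOBAL_BATCH_SIZE)
           .with_options(options))

# Model and optimizer state must be created inside the strategy scope.
with strategy.scope():
    model = build_model()
    model.compile(optimizer="adam",
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=["accuracy"])

# 5. Checkpoint to a shared filesystem so any worker can restore after a failure.
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory="/shared/ckpts",
                                     max_to_keep=3)

model.fit(dataset, epochs=EPOCHS)
manager.save()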
Here's a conceptual diagram illustrating MultiWorkerMirroredStrategy:

Diagram: a MultiWorkerMirroredStrategy setup involving multiple machines, each potentially having multiple GPUs. Communication and gradient synchronization are managed based on the TF_CONFIG variable set on each worker. The dataset is sharded across the workers.
Setting up MultiWorkerMirroredStrategy requires infrastructure management (networking, environment variables) beyond the Python code itself, which is why we focused the hands-on part on MirroredStrategy. However, the core TensorFlow code structure (using strategy.scope(), adapting data pipelines) remains conceptually similar.
This practice exercise demonstrated the fundamental steps for implementing distributed training with tf.distribute.Strategy. By defining your model, optimizer, and training steps within the strategy's scope and ensuring your data pipeline is correctly handled, you can leverage multiple devices to significantly accelerate your training workloads.