This hands-on exercise demonstrates how to set up and run a distributed training job with TensorFlow's `tf.distribute.Strategy`, and should solidify your understanding of how to scale training across multiple devices. For simplicity and broad applicability, the primary focus is on `tf.distribute.MirroredStrategy`, which handles synchronous training on multiple GPUs within a single machine. Multi-worker training is also briefly discussed at the end.

## Prerequisites

Ensure you have TensorFlow installed. If you have multiple GPUs available on your system, TensorFlow should detect them automatically. You can verify this:

```python
import tensorflow as tf
import os    # used later for the multi-worker TF_CONFIG example
import json  # used later for the multi-worker TF_CONFIG example

print("TensorFlow version:", tf.__version__)

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"Detected {len(gpus)} GPU(s):")
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
else:
    print("No GPU detected. MirroredStrategy will fall back to a single CPU replica,")
    print("so there is no speedup from distribution. Consider using")
    print("tf.distribute.OneDeviceStrategy('/cpu:0') for CPU practice.")

# Define a buffer size for shuffling the dataset
BUFFER_SIZE = 10000

# Define the global batch size. This will be split across replicas.
GLOBAL_BATCH_SIZE = 64 * len(gpus) if gpus else 64  # 64 examples per replica

# Define the number of epochs for training
EPOCHS = 5
```

If you don't have multiple GPUs, `MirroredStrategy` will still run on a single GPU or even the CPU (though without performance gains from distribution). The principles remain the same.

## 1. Prepare the Dataset

Distributed training requires careful handling of the dataset: each replica (processing unit, typically a GPU) needs to process a distinct shard of the input data. The `tf.data` API integrates with `tf.distribute.Strategy`. When you pass a `tf.data.Dataset` to `model.fit` within a strategy scope, TensorFlow automatically handles sharding the data across replicas.

Let's create a simple synthetic dataset using `tf.data`:

```python
# Create a simple synthetic dataset
def create_synthetic_dataset(num_samples=10000, num_features=10):
    # Generate random features and binary labels from a noisy linear model
    X = tf.random.normal(shape=(num_samples, num_features))
    coeffs = tf.random.normal(shape=(num_features, 1))
    logits = tf.matmul(X, coeffs) + tf.random.normal(shape=(num_samples, 1), stddev=0.1)
    y = tf.cast(logits > 0, tf.int64)
    return tf.data.Dataset.from_tensor_slices((X, y))

# Create the dataset
dataset = create_synthetic_dataset()

# Shuffle and batch the dataset
# Note: the batch size here is the GLOBAL batch size
dataset = dataset.shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Dataset element spec: {dataset.element_spec}")
print(f"Global batch size: {GLOBAL_BATCH_SIZE}")
if gpus:
    print(f"Per-replica batch size: {GLOBAL_BATCH_SIZE // len(gpus)}")
```

`GLOBAL_BATCH_SIZE` is the total batch size processed across all replicas in one step. `MirroredStrategy` automatically divides this global batch size by the number of replicas to determine the per-replica batch size. `prefetch(tf.data.AUTOTUNE)` is used for performance, allowing data preprocessing to happen concurrently with model execution.

## 2. Define the Model

We'll use a standard Keras `Sequential` model for this example. The architecture itself is unimportant; what matters is where the model is created, which we address in step 4.

```python
def build_model(num_features=10):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(num_features,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model
```
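Before moving on to strategy selection, one optional trick is worth noting if you only have a single physical GPU: you can split it into two logical GPUs so that `MirroredStrategy` still exercises the multi-replica code path. The sketch below uses `tf.config.set_logical_device_configuration` and the `gpus` list from the prerequisites; the 1024 MB memory limits are purely illustrative, and the call must run before anything else initializes the GPU.

```python
# Optional: split one physical GPU into two logical GPUs so that
# MirroredStrategy sees two replicas. This must run before any other
# TensorFlow operation initializes the GPU.
if len(gpus) == 1:
    try:
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
             tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"{len(gpus)} physical GPU -> {len(logical_gpus)} logical GPUs")
    except RuntimeError as e:
        # Logical devices must be configured before the GPU is initialized.
        print(e)
```

With this in place, the strategy created in the next step reports two replicas even on a single-GPU machine. Keep in mind that `GLOBAL_BATCH_SIZE` above was derived from the number of physical GPUs, so adjust it if you want a fixed per-replica batch size.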
## 3. Select and Initialize the Strategy

Now we choose our distribution strategy. For multiple GPUs on one machine, `MirroredStrategy` is the typical choice.

```python
# If GPUs are available, use MirroredStrategy. Otherwise, fall back to
# OneDeviceStrategy on the CPU for demonstration purposes.
if gpus:
    strategy = tf.distribute.MirroredStrategy()
    print(f"Using MirroredStrategy with {strategy.num_replicas_in_sync} replicas.")
else:
    strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
    print("Using OneDeviceStrategy (fallback).")

# The effective number of replicas participating in synchronous training
num_replicas = strategy.num_replicas_in_sync
```

## 4. Create Model and Optimizer within Strategy Scope

This is a critical step. To ensure that the model variables and the optimizer's state are created on the appropriate devices and mirrored correctly, their creation must happen inside `strategy.scope()`.

```python
# Create the model, optimizer, loss, and metrics within the strategy scope
with strategy.scope():
    # Model building
    model = build_model()

    # Optimizer definition
    optimizer = tf.keras.optimizers.Adam()

    # Loss function definition
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction=tf.keras.losses.Reduction.NONE  # Important for distributed training!
    )

    # Define metrics
    train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

    # Compiling is not needed for the custom loop below. If you use
    # model.fit instead, compile inside this scope:
    # model.compile(optimizer=optimizer, loss=loss_object, metrics=['accuracy'])
```

Notice the `reduction=tf.keras.losses.Reduction.NONE` argument for the loss function. When training with `tf.distribute.Strategy`, the loss should be calculated per example within each replica's batch; the strategy then aggregates these losses across all replicas (typically by summing them and dividing by the global batch size). If you used the default `SUM_OVER_BATCH_SIZE` reduction, each replica would average over its own local batch, and the subsequent aggregation by the strategy would scale the loss incorrectly.

## 5. Implement the Training Step

You can use the standard `model.fit` API, which integrates with `tf.distribute.Strategy` (shown later). To see what the strategy actually does, let's first write a custom training step using `strategy.run` and `strategy.reduce`.

```python
# Compute one replica's contribution to the loss
def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    # compute_average_loss sums the per-example losses and divides by the
    # GLOBAL batch size. Summing these per-replica values across replicas
    # then yields the mean loss over the global batch.
    return tf.nn.compute_average_loss(per_example_loss,
                                      global_batch_size=GLOBAL_BATCH_SIZE)

# Define the training step function
@tf.function
def distributed_train_step(inputs):
    features, labels = inputs

    def step_fn(features, labels):
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            loss = compute_loss(labels, predictions)

        # Calculate and apply gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update metrics
        train_accuracy.update_state(labels, predictions)
        return loss

    # Run the step function on each replica
    per_replica_losses = strategy.run(step_fn, args=(features, labels))

    # Aggregate the per-replica losses. Because compute_loss already divides
    # by the global batch size, summing across replicas yields the average
    # loss over the global batch.
    mean_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                per_replica_losses, axis=None)
    return mean_loss
```
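To see concretely why the `Reduction.NONE` plus `tf.nn.compute_average_loss` combination gives the right number, here is a small standalone illustration. The per-example losses and the two-replica split are made up for the example; nothing here runs under a strategy.

```python
import tensorflow as tf

# Hypothetical setup: 2 replicas, GLOBAL_BATCH_SIZE = 4, so each replica
# processes 2 examples. The per-example losses are invented for illustration.
per_example = tf.constant([0.2, 0.4, 0.6, 0.8])
replica_0, replica_1 = per_example[:2], per_example[2:]

# What the code above does: each replica divides its local sum by the GLOBAL
# batch size, and strategy.reduce(SUM) adds the per-replica results.
correct = tf.reduce_sum(replica_0) / 4 + tf.reduce_sum(replica_1) / 4
print(float(correct))  # 0.5 -> the true mean loss over the global batch

# What the default SUM_OVER_BATCH_SIZE reduction would do: each replica takes
# the mean over its own local batch, and summing those across replicas
# overstates the loss by a factor of the replica count.
wrong = tf.reduce_mean(replica_0) + tf.reduce_mean(replica_1)
print(float(wrong))    # 1.0 -> twice the true mean
```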
A few points in `distributed_train_step` deserve attention:

- The `@tf.function` decorator compiles the training step into a TensorFlow graph for performance.
- `strategy.run(step_fn, ...)` executes `step_fn` on each replica, providing it with that replica's slice of the batch. `step_fn` contains the core forward pass, loss calculation, and gradient computation/application for one replica.
- `compute_loss` returns each replica's share of the loss, scaled by the global batch size via `tf.nn.compute_average_loss`.
- `strategy.reduce(...)` aggregates values (like the loss) returned from each replica. Here we sum the per-replica losses; the result is the average loss over the global batch because `compute_loss` already incorporated the global batch size scaling.

## 6. Run the Training Loop

Now we iterate through the epochs and steps, calling our distributed training step. For a custom loop, the dataset must first be distributed with `strategy.experimental_distribute_dataset` so that each replica receives its own shard of every global batch.

```python
print("Starting distributed training...")

# Distribute the dataset so each replica gets its own slice of every batch
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0

    # Reset metrics at the start of each epoch
    train_accuracy.reset_state()

    # Iterate over the distributed dataset
    for batch_inputs in dist_dataset:
        # Run the distributed training step
        batch_loss = distributed_train_step(batch_inputs)
        total_loss += batch_loss
        num_batches += 1

        # Optional: print progress within the epoch
        # if num_batches % 50 == 0:
        #     print(f"  Epoch {epoch+1}, Batch {num_batches}, "
        #           f"Loss: {batch_loss.numpy():.4f}, "
        #           f"Accuracy: {train_accuracy.result().numpy():.4f}")

    # Calculate the average loss over the epoch
    epoch_loss = total_loss / num_batches
    epoch_accuracy = train_accuracy.result()

    print(f"Epoch {epoch+1}/{EPOCHS}, "
          f"Loss: {float(epoch_loss):.4f}, "
          f"Accuracy: {float(epoch_accuracy):.4f}")

print("Distributed training finished.")
```

## Using model.compile and model.fit

Alternatively, if you prefer the higher-level Keras API, you can use `model.compile` and `model.fit`. Ensure the model, optimizer, and metrics are created within `strategy.scope()`; `model.fit` then handles the distribution automatically when given a `tf.data.Dataset`.

```python
# Re-create the model and optimizer within the strategy scope
with strategy.scope():
    model_fit = build_model()
    optimizer_fit = tf.keras.optimizers.Adam()

    # The standard (default) loss reduction is fine for compile/fit;
    # Keras handles the distributed scaling internally.
    loss_fit = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    model_fit.compile(optimizer=optimizer_fit, loss=loss_fit, metrics=['accuracy'])

print("\nStarting distributed training with model.fit...")

# model.fit handles dataset distribution automatically
history = model_fit.fit(dataset, epochs=EPOCHS, verbose=1)

print("Distributed training with model.fit finished.")
print("History:", history.history)
```

Using `model.fit` is often simpler, but understanding the custom loop with `strategy.run` provides deeper insight into how the distribution mechanism operates.
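For completeness, evaluation in the custom loop follows the same pattern as `distributed_train_step`, just without the gradient tape. The sketch below reuses `strategy`, `model`, and `loss_object` from above and assumes a hypothetical validation dataset `val_dataset` built with the same pipeline as `dataset`.

```python
# Validation metrics, created inside the strategy scope like the training ones
with strategy.scope():
    val_loss_metric = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.BinaryAccuracy(name='val_accuracy')

@tf.function
def distributed_test_step(inputs):
    features, labels = inputs

    def step_fn(features, labels):
        # Forward pass only: no GradientTape, and training=False
        predictions = model(features, training=False)
        per_example_loss = loss_object(labels, predictions)
        val_loss_metric.update_state(per_example_loss)
        val_accuracy.update_state(labels, predictions)

    strategy.run(step_fn, args=(features, labels))

# Hypothetical usage, assuming val_dataset exists:
# dist_val_dataset = strategy.experimental_distribute_dataset(val_dataset)
# for batch in dist_val_dataset:
#     distributed_test_step(batch)
# print("Val loss:", float(val_loss_metric.result()),
#       "Val accuracy:", float(val_accuracy.result()))
```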
## Approaches for Multi-Worker Training (MultiWorkerMirroredStrategy)

While `MirroredStrategy` handles multiple devices on one machine, `MultiWorkerMirroredStrategy` scales synchronous training across multiple machines (workers). Setting it up involves:

1. **TF_CONFIG environment variable:** Each worker needs a `TF_CONFIG` environment variable set (usually as a JSON string). This variable tells the worker about the entire cluster setup: the addresses of all workers and the current worker's type and index.

   ```
   # Example TF_CONFIG for worker 0 (on machine A)
   {
     "cluster": {
       "worker": ["machine-a.example.com:20000", "machine-b.example.com:20000"]
     },
     "task": {"type": "worker", "index": 0}
   }

   # Example TF_CONFIG for worker 1 (on machine B)
   {
     "cluster": {
       "worker": ["machine-a.example.com:20000", "machine-b.example.com:20000"]
     },
     "task": {"type": "worker", "index": 1}
   }
   ```

2. **Strategy instantiation:** Instantiate `tf.distribute.MultiWorkerMirroredStrategy()` instead of `MirroredStrategy` (sketched after the diagram below).

3. **Code execution:** Run the same Python script on all worker machines. TensorFlow uses `TF_CONFIG` to coordinate them.

4. **Dataset sharding:** Ensure your `tf.data` pipeline correctly shards data across workers. Often this involves setting `tf.data.experimental.AutoShardPolicy.DATA` or `tf.data.experimental.AutoShardPolicy.FILE` in the dataset options, or manually sharding based on `TF_CONFIG`.

5. **Checkpointing:** Checkpointing (e.g., using `tf.train.CheckpointManager`) is essential for fault tolerance in multi-worker setups. Checkpoints should typically be saved to a shared filesystem accessible by all workers.

Here's a diagram illustrating `MultiWorkerMirroredStrategy`:

```dot
digraph G {
    rankdir=LR;
    node [shape=record, style=filled, color="#a5d8ff", fillcolor="#e9ecef"];

    subgraph cluster_0 {
        label = "Worker 0 (Machine A)";
        style=filled;
        color="#dee2e6";
        node [color="#1c7ed6"];
        worker0 [label="{ TF_CONFIG | { type: worker, index: 0 } }"];

        subgraph cluster_0_gpus {
            label = "GPUs";
            style=filled;
            color="#ced4da";
            node [shape=box, style=filled, color="#4263eb", fillcolor="#bac8ff"];
            gpu0_0 [label="GPU 0"];
            gpu0_1 [label="GPU 1"];
        }

        worker0 -> gpu0_0 [style=dotted];
        worker0 -> gpu0_1 [style=dotted];
    }

    subgraph cluster_1 {
        label = "Worker 1 (Machine B)";
        style=filled;
        color="#dee2e6";
        node [color="#1c7ed6"];
        worker1 [label="{ TF_CONFIG | { type: worker, index: 1 } }"];

        subgraph cluster_1_gpus {
            label = "GPUs";
            style=filled;
            color="#ced4da";
            node [shape=box, style=filled, color="#4263eb", fillcolor="#bac8ff"];
            gpu1_0 [label="GPU 0"];
            gpu1_1 [label="GPU 1"];
        }

        worker1 -> gpu1_0 [style=dotted];
        worker1 -> gpu1_1 [style=dotted];
    }

    coord [label="Coordinator\n(Implicit, via TF_CONFIG)", shape=Mdiamond, style=filled, color="#7048e8", fillcolor="#d0bfff"];

    worker0 -> coord [label=" Sync Gradients ", dir=both, color="#7048e8"];
    worker1 -> coord [label=" Sync Gradients ", dir=both, color="#7048e8"];

    dataset [label="Sharded Dataset", shape=cylinder, style=filled, color="#0ca678", fillcolor="#96f2d7"];
    dataset -> worker0 [label="Shard 0"];
    dataset -> worker1 [label="Shard 1"];
}
```

*A MultiWorkerMirroredStrategy setup involving multiple machines, each potentially having multiple GPUs. Communication and gradient synchronization are managed based on the TF_CONFIG variable set on each worker, and the dataset is sharded across the workers.*

Setting up `MultiWorkerMirroredStrategy` requires infrastructure management (networking, environment variables) in addition to the Python code itself, which is why the hands-on part focused on `MirroredStrategy`. The core TensorFlow code structure (using `strategy.scope()`, adapting the data pipeline) remains similar.
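Putting the points above together, a minimal multi-worker sketch might look like the following. Everything here is illustrative: the worker addresses and the `/shared/ckpts` path are placeholders, `TF_CONFIG` is normally set by your cluster scheduler rather than in code, and the same script (reusing `build_model`, `dataset`, and `EPOCHS` from earlier) would run on every worker with only the task index differing.

```python
import json
import os

import tensorflow as tf

# Illustrative only: in practice the scheduler sets TF_CONFIG for each worker.
# The addresses are placeholders; "index" differs on each machine.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["machine-a.example.com:20000",
                           "machine-b.example.com:20000"]},
    "task": {"type": "worker", "index": 0},
})

# Create the strategy early, before other TensorFlow operations.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = build_model()
    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["accuracy"])

# Shard the dataset by data across workers (FILE sharding needs file-based sources).
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
    tf.data.experimental.AutoShardPolicy.DATA
multi_worker_dataset = dataset.with_options(options)

# Checkpoints go to a filesystem visible to every worker (placeholder path).
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory="/shared/ckpts",
                                     max_to_keep=3)

model.fit(multi_worker_dataset, epochs=EPOCHS)
manager.save()
```

In a real deployment you would typically also arrange for only the chief (worker with index 0) to write checkpoints to the final location, or rely on `tf.keras.callbacks.BackupAndRestore` for fault tolerance; both are beyond the scope of this sketch.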
This practice exercise demonstrated the fundamental steps for implementing distributed training with `tf.distribute.Strategy`. By defining your model, optimizer, and training steps within the strategy's scope and ensuring your data pipeline is handled correctly, you can leverage multiple devices to significantly accelerate your training workloads.