In this exercise we build a basic data-parallel training loop that distributes a simple machine learning task across multiple devices (GPUs, TPU cores, or simulated CPU devices) using `pmap`. The goal is to train a model where each device processes a portion of the data batch, computes gradients locally, and then collaborates to compute the average gradient for a synchronous parameter update.

This demonstrates the core workflow of Single-Program Multiple-Data (SPMD) execution with `pmap`: write the code for one device, and `pmap` handles replicating the execution and coordinating the necessary communication.

### Setup and Dependencies

First, let's import JAX and NumPy and check the number of available devices. The code works on multi-GPU or TPU setups, and you can also simulate multiple devices on a CPU for testing.

```python
# To simulate multiple devices on a CPU, set this environment variable
# *before* importing jax (e.g., at the very top of your script):
# import os
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap, value_and_grad
from jax.lax import pmean

# Check available devices
num_devices = jax.local_device_count()
print(f"Number of available devices: {num_devices}")

# Ensure we have at least 2 devices for a meaningful example
if num_devices < 2:
    print("Warning: this example is best run with multiple devices (real or simulated).")
    # You can still run it, but pmap won't provide any parallelism benefit.

# Generate a PRNG key
key = jax.random.PRNGKey(0)
```

### Define a Simple Model and Loss Function

We'll use a basic linear regression model for simplicity. The task is to find weights `w` and bias `b` such that $y \approx Xw + b$, by minimizing the mean squared error over a batch, $L(w, b) = \frac{1}{N}\sum_{i=1}^{N}\bigl(x_i^\top w + b - y_i\bigr)^2$.

```python
def linear_model(params, x):
    """A simple linear model prediction function."""
    w, b = params
    return jnp.dot(x, w) + b


def mean_squared_error(params, x_batched, y_batched):
    """Calculates the mean squared error loss."""
    predictions = linear_model(params, x_batched)
    error = predictions - y_batched
    loss = jnp.mean(error**2)
    return loss
```
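As a quick sanity check before introducing any parallelism, we can evaluate the loss on a few lines of hand-made data. The values below are illustrative and chosen so that the predictions match the targets exactly.

```python
# Illustrative check only: with w = [1, 1, 1] and b = 0, every prediction for
# an all-ones input row is 3.0, so the loss against targets of 3.0 is zero.
dummy_params = (jnp.ones((3,)), jnp.array(0.0))
dummy_x = jnp.ones((4, 3))        # 4 samples, 3 features
dummy_y = jnp.full((4,), 3.0)     # targets equal to the model's predictions
print(mean_squared_error(dummy_params, dummy_x, dummy_y))  # prints 0.0
```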
### Prepare Sample Data and Shard It

We need data that can be split across our devices. The standard practice for data parallelism with `pmap` is to give the input arrays a leading dimension equal to the number of devices. Each slice along this dimension (`data[i]`) is sent to the i-th device.

```python
# Generate synthetic data
feature_dim = 5
num_samples = 100 * num_devices  # Ensure total samples are divisible by num_devices

key, w_key, b_key, x_key, noise_key = jax.random.split(key, 5)

# True parameters (we want the model to learn these)
true_w = jax.random.normal(w_key, (feature_dim,))
true_b = jax.random.normal(b_key, ())

# Generate features X and targets y
X = jax.random.normal(x_key, (num_samples, feature_dim))
noise = jax.random.normal(noise_key, (num_samples,)) * 0.1
y = jnp.dot(X, true_w) + true_b + noise

# Reshape data for pmap: add a leading dimension for devices.
# Each device will get batch_size_per_device = num_samples // num_devices samples.
batch_size_per_device = num_samples // num_devices
sharded_X = X.reshape((num_devices, batch_size_per_device, feature_dim))
sharded_y = y.reshape((num_devices, batch_size_per_device))

print(f"Total samples: {num_samples}")
print(f"Samples per device: {batch_size_per_device}")
print(f"Shape of sharded X: {sharded_X.shape}")  # (num_devices, batch_size_per_device, feature_dim)
print(f"Shape of sharded y: {sharded_y.shape}")  # (num_devices, batch_size_per_device)
```

### Define the Distributed Training Step

This is the core of the example. We'll write a function designed to run on a single device, calculating the loss and gradients for its local shard of data, and then use `pmap` to run it in parallel across all devices. Crucially, inside the pmapped function we use `jax.lax.pmean` to average the gradients computed independently on each device before performing the parameter update.

```python
# Function that calculates loss and gradients for one device's data shard
compute_loss_and_grads = value_and_grad(mean_squared_error)

# The training step function that will be pmapped.
# It takes the current parameters and one device's shard of data.
def distributed_train_step(params, x_shard, y_shard, learning_rate):
    """Performs one training step on sharded data across devices."""
    # 1. Calculate loss and gradients locally on each device
    loss, grads = compute_loss_and_grads(params, x_shard, y_shard)

    # 2. Average gradients across all devices using pmean.
    #    'batch' is the axis name we'll define in pmap; pmean computes the
    #    mean across the devices mapped over that axis.
    avg_grads = pmean(grads, axis_name='batch')

    # 3. Update parameters (identically on all devices) with plain gradient descent
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, avg_grads)

    # Also average the loss for reporting purposes (optional, but useful)
    avg_loss = pmean(loss, axis_name='batch')

    return new_params, avg_loss


# Use pmap to create the parallel version of the training step.
# - `axis_name='batch'` names the dimension being mapped over; collective
#   operations like pmean inside the function refer to this name.
# - `in_axes=(0, 0, 0, None)` specifies how the inputs are mapped:
#   - params: mapped along axis 0 (we pass explicitly replicated parameters
#     that carry a leading device dimension, so each device gets its own,
#     identical, copy)
#   - x_shard: sharded along axis 0 (the leading device dimension)
#   - y_shard: sharded along axis 0
#   - learning_rate: broadcast (and additionally marked static below, so it is
#     baked into the compiled program as a constant)
# - `out_axes=0` stacks the outputs (new_params, avg_loss) along axis 0.
#   Because the update is identical on every device, all slices of `new_params`
#   along axis 0 are the same, and avg_loss is likewise identical across
#   devices after pmean.
# Note that pmap compiles the function with XLA, so no separate jit is needed.
p_train_step = pmap(
    distributed_train_step,
    axis_name='batch',
    in_axes=(0, 0, 0, None),
    out_axes=0,
    static_broadcasted_argnums=(3,),  # learning_rate is a compile-time constant
)
```

Note: the `static_broadcasted_argnums` argument tells `pmap` (analogous to `static_argnums` for `jit`) to treat the argument at that position (the learning rate) as a compile-time constant that is broadcast to every device; calling with a different value triggers a recompile, so it suits hyperparameters that rarely change. The `in_axes` specification is fundamental: `0` means split the corresponding argument along its first axis and send the resulting slices to the devices, while `None` means replicate the argument so every device receives the same copy.
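To make the collective concrete, here is a tiny illustration (separate from the training code) of what `pmean` does inside `pmap`: each device starts with its own value, and after `pmean` every device holds the mean over the named axis.

```python
# Each device receives one element of `values`; pmean over the named axis
# replaces it with the mean across all devices, so every entry of the
# stacked result is identical.
values = jnp.arange(num_devices, dtype=jnp.float32)
device_means = pmap(lambda v: pmean(v, axis_name='i'), axis_name='i')(values)
print(device_means)  # e.g. [1.5 1.5 1.5 1.5] on 4 devices
```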
### Initialize Parameters and Replicate

Before starting the training loop, we need to initialize our model parameters (`w` and `b`) and replicate them across all devices. `pmap` expects its mapped inputs to carry a leading device dimension, whether the slices differ (sharded data) or are identical (replicated parameters). Since the parameters should be identical on every device at the start (and remain synchronized), we replicate them by stacking `num_devices` copies.

```python
# Initialize parameters randomly
key, w_init_key, b_init_key = jax.random.split(key, 3)
initial_w = jax.random.normal(w_init_key, (feature_dim,))
initial_b = jax.random.normal(b_init_key, ())
params = (initial_w, initial_b)

# Replicate the initial parameters across all devices.
# Method 1: stack num_devices copies of each leaf with tree_map
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * num_devices), params)

# Method 2 (equivalent): copy the pytree onto every local device in one call
# replicated_params = jax.device_put_replicated(params, jax.local_devices())

print("Shape of initial w:   ", initial_w.shape)
print("Shape of replicated w:", replicated_params[0].shape)  # (num_devices, feature_dim)
print("Shape of initial b:   ", initial_b.shape)
print("Shape of replicated b:", replicated_params[1].shape)  # (num_devices,)
```

### The Training Loop

Now we can run the training loop. In each step we call `p_train_step` with the currently replicated parameters and the sharded data. The function handles the parallel execution, the gradient averaging via `pmean`, and the synchronous update.

```python
learning_rate = 0.05
num_epochs = 50  # One full-batch step per epoch here (see the minibatch sketch below)

print("\nStarting distributed training...")

current_params = replicated_params

for epoch in range(num_epochs):
    # Execute the parallel training step with the replicated parameters
    # and the sharded data for this "epoch".
    current_params, loss = p_train_step(
        current_params, sharded_X, sharded_y, learning_rate)

    # The loss returned by p_train_step is replicated (identical on all devices
    # because we used pmean), so we only need the value from one device.
    # jax.device_get transfers data from device to host (slow if done often);
    # indexing the first element [0] also works and is often preferred in loops.
    epoch_loss = loss[0]

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
```
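The loop above takes a single full-batch step per epoch. If you want to step over smaller minibatches instead, the sketch below shows one way to slice the per-device batch axis; the function name `train_minibatched` and the `minibatch_size` value are illustrative, and the minibatch size is assumed to divide `batch_size_per_device` evenly.

```python
# Sketch only (not called above): iterate over per-device minibatches within
# each epoch. The leading device dimension of the sharded arrays is preserved;
# only the per-device batch axis is sliced.
def train_minibatched(params, num_epochs, minibatch_size=20):
    num_minibatches = batch_size_per_device // minibatch_size
    for _ in range(num_epochs):
        for i in range(num_minibatches):
            start, stop = i * minibatch_size, (i + 1) * minibatch_size
            x_mb = sharded_X[:, start:stop]  # (num_devices, minibatch_size, feature_dim)
            y_mb = sharded_y[:, start:stop]  # (num_devices, minibatch_size)
            params, _ = p_train_step(params, x_mb, y_mb, learning_rate)
    return params
```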
After training, the parameters in `current_params` are still replicated, so we grab the final values from one device.

```python
final_params = jax.tree_util.tree_map(lambda x: x[0], current_params)

print("\nTraining finished.")
print("Learned w:", final_params[0])
print("True w:   ", true_w)
print("Learned b:", final_params[1])
print("True b:   ", true_b)
```

### Verification

After each call to `p_train_step`, `current_params` holds the updated parameters, replicated across devices. Because the gradient averaging with `pmean` ensures all devices compute the same average gradient, and the update rule is deterministic, the parameters on all devices should remain identical throughout training. You can verify this, for example, by checking `jnp.allclose(current_params[0][0], current_params[0][1])` to compare the weights `w` on devices 0 and 1.

### Summary

In this practice section, we implemented a data-parallel training loop using `pmap`:

1. **Data sharding**: we reshaped the input data `X` and `y` so that the first dimension matched the number of devices.
2. **SPMD function**: we defined `distributed_train_step`, which encapsulates the logic for a single device: calculate loss and gradients on its data shard.
3. **Collective communication**: inside this function, we used `jax.lax.pmean` with an `axis_name` (`'batch'`) to average the gradients calculated across all devices participating in the `pmap`.
4. **Synchronous update**: using the averaged gradient, each device performed the exact same parameter update, keeping the parameters synchronized.
5. **`pmap` transformation**: we applied `pmap` to `distributed_train_step`, specifying the `axis_name` and using `in_axes` to control how arguments were distributed (`0` for the sharded data and for the replicated parameters, which carry a leading device dimension; `None` for the broadcast learning rate).
6. **Replication**: we explicitly replicated the initial parameters before starting the loop.

This pattern forms the basis for scaling many JAX computations, especially deep learning training, across multiple accelerators. You can adapt the template by replacing the linear model and loss function with more complex neural networks and incorporating optimizers like Adam, while the core `pmap` structure for data parallelism remains largely the same.
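As one example of that adaptation, here is a minimal sketch of swapping the hand-written gradient-descent update for an Adam optimizer, assuming the `optax` library is available; the names `adam_train_step`, `p_adam_train_step`, and the learning rate are illustrative.

```python
import optax  # assumed to be installed separately

# Sketch: the pmap structure is unchanged; only the update rule differs.
# Gradients are still averaged with pmean, and every device applies the same
# optimizer update, so parameters and optimizer state stay synchronized.
optimizer = optax.adam(learning_rate=1e-2)

def adam_train_step(params, opt_state, x_shard, y_shard):
    loss, grads = compute_loss_and_grads(params, x_shard, y_shard)
    grads = pmean(grads, axis_name='batch')
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, pmean(loss, axis_name='batch')

# All arguments carry a leading device dimension, so the default in_axes=0 applies.
p_adam_train_step = pmap(adam_train_step, axis_name='batch')

# Like the parameters, the optimizer state must be replicated before the first call:
# replicated_opt_state = jax.device_put_replicated(optimizer.init(params), jax.local_devices())
```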