Let's put the concepts from this chapter into practice by building a basic data-parallel training loop. We'll simulate having multiple devices (such as GPUs or TPU cores) and distribute a simple machine learning task across them using pmap. Our goal is to train a model where each device processes a portion of the data batch, computes gradients locally, and then collaborates to compute the average gradient for a synchronous parameter update.
This exercise demonstrates the core workflow of Single-Program Multiple-Data (SPMD) execution with pmap: you write the code for a single device, and pmap handles replicating the execution across devices and coordinating the necessary communication.
First, let's import JAX, NumPy, and check the number of available devices. For this example, you can simulate multiple devices even if you're running on a CPU; the same code works seamlessly on multi-GPU or TPU setups.
import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap, grad, value_and_grad, jit
from jax.lax import pmean
# Check available devices
num_devices = jax.local_device_count()
print(f"Number of available devices: {num_devices}")
# If running on a CPU, JAX can expose multiple simulated devices for pmap testing.
# These settings must be applied before JAX initializes its backend, ideally at the
# very top of the script, e.g.:
# import os
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
# os.environ["JAX_PLATFORMS"] = "cpu"  # optionally force the CPU backend
# num_devices = jax.local_device_count()  # Update num_devices if simulating
# print(f"Number of simulated devices: {num_devices}")
# Ensure we have at least 2 devices for a meaningful example
if num_devices < 2:
    print("Warning: This example is best run with multiple devices (real or simulated).")
    # You can still run it, but pmap won't provide parallelism benefits.
# Generate a PRNG key
key = jax.random.PRNGKey(0)
We'll use a basic linear regression model for simplicity. The task is to find weights w and bias b such that y ≈ Xw + b.
def linear_model(params, x):
    """A simple linear model prediction function."""
    w, b = params
    return jnp.dot(x, w) + b

def mean_squared_error(params, x_batched, y_batched):
    """Calculates the mean squared error loss."""
    predictions = linear_model(params, x_batched)
    error = predictions - y_batched
    loss = jnp.mean(error**2)
    return loss
We need data that can be split across our devices. The standard practice for data parallelism with pmap is to ensure the first dimension of your input arrays corresponds to the number of devices. Each slice along this dimension (data[i]) is sent to the i-th device.
# Generate synthetic data
feature_dim = 5
num_samples = 100 * num_devices # Ensure total samples are divisible by num_devices
key, w_key, b_key, x_key, noise_key = jax.random.split(key, 5)
# True parameters (we want the model to learn these)
true_w = jax.random.normal(w_key, (feature_dim,))
true_b = jax.random.normal(b_key, ())
# Generate features X and targets y
X = jax.random.normal(x_key, (num_samples, feature_dim))
noise = jax.random.normal(noise_key, (num_samples,)) * 0.1
y = jnp.dot(X, true_w) + true_b + noise
# Reshape data for pmap: add a leading dimension for devices
# Each device will get batch_size_per_device = num_samples // num_devices samples
batch_size_per_device = num_samples // num_devices
sharded_X = X.reshape((num_devices, batch_size_per_device, feature_dim))
sharded_y = y.reshape((num_devices, batch_size_per_device))
print(f"Total samples: {num_samples}")
print(f"Samples per device: {batch_size_per_device}")
print(f"Shape of sharded X: {sharded_X.shape}") # Should be (num_devices, batch_size_per_device, feature_dim)
print(f"Shape of sharded y: {sharded_y.shape}") # Should be (num_devices, batch_size_per_device)
This is the core of the example. We'll create a function designed to run on a single device, calculating the loss and gradients for its local shard of data. Then, we'll use pmap to run this function in parallel across all devices. Crucially, inside the pmapped function, we'll use jax.lax.pmean to average the gradients computed independently on each device before performing the parameter update.
# Define the function to calculate loss and gradients on one device
compute_loss_and_grads = value_and_grad(mean_squared_error)
# Define the training step function that will be pmapped
# It takes the current parameters and a shard of data for one device
def distributed_train_step(params, x_shard, y_shard, learning_rate):
    """Performs one training step on sharded data across devices."""
    # 1. Calculate loss and gradients locally on each device
    loss, grads = compute_loss_and_grads(params, x_shard, y_shard)
    # 2. Average gradients across all devices using pmean
    #    'batch' is the axis name we'll define in pmap
    #    pmean calculates the mean across the devices mapped over the 'batch' axis
    avg_grads = pmean(grads, axis_name='batch')
    # 3. Update parameters (identically on all devices)
    #    Simple gradient descent update
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, avg_grads)
    # We also average the loss for reporting purposes (optional, but useful)
    avg_loss = pmean(loss, axis_name='batch')
    return new_params, avg_loss
# Use pmap to create the parallel version of the training step
# - `axis_name='batch'` provides a name for the dimension being mapped over.
#   This name is used by collective operations like pmean inside the function.
# - `in_axes=(0, 0, 0, None)` specifies how inputs are mapped:
#   - params: mapped along axis 0 (the replicated copies live along the leading
#     device dimension, one copy per device)
#   - x_shard: sharded along axis 0 (the first dimension indexes devices)
#   - y_shard: sharded along axis 0
#   - learning_rate: broadcast (every device receives the same value)
# - `out_axes=0` specifies that the outputs (new_params, avg_loss) are stacked
#   along a leading device axis. Since the parameter update is identical on all
#   devices, all slices along axis 0 of `new_params` will be the same, and
#   avg_loss is likewise identical across devices after pmean.
# pmap compiles the function with XLA automatically, so no separate jit is needed.
p_train_step = pmap(
    distributed_train_step,
    axis_name='batch',
    in_axes=(0, 0, 0, None),
    out_axes=0,  # Outputs carry a leading device axis, identical across devices here
    static_broadcasted_argnums=(3,)  # learning_rate is treated as a compile-time constant
)
Note: The static_broadcasted_argnums argument tells pmap that the argument at that index (the learning rate) is treated as a compile-time constant; calling the function with a different value for it triggers recompilation, so it is best suited to values that rarely change (jit offers the analogous static_argnums option). The in_axes specification is fundamental: 0 means split the corresponding argument along its first axis and distribute the resulting slices to the devices, while None means broadcast the argument so that every device receives the same copy.
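As a quick toy illustration (hypothetical, not part of the training loop), the snippet below maps a small array over the device axis while broadcasting a scalar: each device receives one row of the array and the same scale factor. The names toy_data and scale_rows are introduced here only for this demonstration.
# Toy illustration of in_axes semantics (not used in the training loop below)
toy_data = jnp.arange(float(num_devices * 3)).reshape(num_devices, 3)
scale_rows = pmap(lambda x, s: x * s, in_axes=(0, None))  # x sharded, s broadcast
print(scale_rows(toy_data, 2.0))  # each device scales its own row by the same factor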
Before starting the training loop, we need to initialize our model parameters (w and b) and then replicate them across all devices. pmap expects inputs that are either sharded or replicated along a leading device dimension. Since the parameters should be identical across devices at the start (and remain synchronized), we replicate them by stacking one copy per device.
# Initialize parameters randomly
key, w_init_key, b_init_key = jax.random.split(key, 3)
initial_w = jax.random.normal(w_init_key, (feature_dim,))
initial_b = jax.random.normal(b_init_key, ())
params = (initial_w, initial_b)
# Replicate the initial parameters across all devices by stacking one copy
# per device along a new leading axis.
# Method 1: stack copies with tree_map
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * num_devices), params)
# Method 2 (equivalent): place one copy on each device explicitly
# replicated_params = jax.device_put_replicated(params, jax.local_devices())
print("Shape of initial w:", initial_w.shape)
print("Shape of replicated w:", replicated_params[0].shape) # Should be (num_devices, feature_dim)
print("Shape of initial b:", initial_b.shape)
print("Shape of replicated b:", replicated_params[1].shape) # Should be (num_devices,)
Now we can run the training loop. In each step, we call our p_train_step function with the currently replicated parameters and the sharded data. The function handles the parallel execution, the gradient averaging via pmean, and the synchronous update.
learning_rate = 0.05
num_epochs = 50 # Using the full dataset sharded once per epoch here
print("\nStarting distributed training...")
current_params = replicated_params
for epoch in range(num_epochs):
    # Execute the parallel training step
    # Pass the replicated parameters and the sharded data for this "epoch"
    current_params, loss = p_train_step(current_params, sharded_X, sharded_y, learning_rate)
    # The loss returned by p_train_step is replicated (identical on all devices
    # because we used pmean). We only need the value from one device.
    # jax.device_get transfers data from device to host (can be slow if done often);
    # indexing the first element [0] also works and is often preferred within loops.
    epoch_loss = loss[0]  # Get loss from the first device
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
# After training, the parameters in current_params are still replicated.
# Get the final parameters from one device.
final_params = jax.tree_util.tree_map(lambda x: x[0], current_params)
print("\nTraining finished.")
print("Learned w:", final_params[0])
print("True w: ", true_w)
print("Learned b:", final_params[1])
print("True b: ", true_b)
After each call to p_train_step, the current_params variable holds the updated parameters, replicated across devices. Because the gradient averaging (pmean) ensures all devices compute the same average gradient, and the update rule is deterministic, the parameters on all devices remain identical throughout training. You can verify this by checking, for example, jnp.allclose(current_params[0][0], current_params[0][1]) to compare the weights w on device 0 and device 1.
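As a quick sanity check (assuming at least two devices are available), that comparison looks like this:
# Verify that parameters stayed synchronized across devices (needs >= 2 devices)
if num_devices >= 2:
    w_synced = jnp.allclose(current_params[0][0], current_params[0][1])
    b_synced = jnp.allclose(current_params[1][0], current_params[1][1])
    print(f"w identical on devices 0 and 1: {w_synced}")
    print(f"b identical on devices 0 and 1: {b_synced}")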
In this practice section, we successfully implemented a data-parallel training loop using pmap:
- Data sharding: We reshaped X and y so that the first dimension matched the number of devices, giving each device its own slice of the batch.
- Per-device step: We wrote distributed_train_step, which encapsulates the logic for a single device: calculate the loss and gradients on its local data shard.
- Gradient averaging: We used jax.lax.pmean with an axis_name ('batch') to average the gradients calculated across all devices participating in the pmap.
- The pmap transformation: We applied pmap to distributed_train_step, specifying the axis_name and using in_axes to control how arguments were distributed (0 for the sharded data and the parameters stacked along the device axis, None for the broadcast learning rate).
This pattern forms the basis for scaling many JAX computations, especially deep learning training, across multiple accelerators. You can adapt this template by replacing the linear model and loss function with a more complex neural network and incorporating an optimizer such as Adam, while the core pmap structure for data parallelism remains largely the same; a sketch of that adaptation follows.
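Here is a minimal sketch of that adaptation using the optax library's Adam optimizer in place of the hand-written SGD update. It assumes optax is installed; the names adam_train_step and p_adam_step are our own for this example, and the model, loss, and sharded data are unchanged.
import optax

# Hypothetical Adam variant of the training step (sketch, not the tutorial's code)
optimizer = optax.adam(learning_rate=1e-3)

def adam_train_step(params, opt_state, x_shard, y_shard):
    """One data-parallel step with Adam: local grads, pmean, identical update."""
    loss, grads = compute_loss_and_grads(params, x_shard, y_shard)
    grads = pmean(grads, axis_name='batch')        # average gradients across devices
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)  # same update applied on every device
    return params, opt_state, pmean(loss, axis_name='batch')

p_adam_step = pmap(adam_train_step, axis_name='batch', in_axes=(0, 0, 0, 0))

# Like the parameters, the optimizer state is replicated with a leading device axis.
opt_state = optimizer.init(params)  # params: the unreplicated (w, b) tuple
replicated_opt_state = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * num_devices), opt_state)

# Usage inside the loop would then look like:
# new_params, new_opt_state, loss = p_adam_step(
#     replicated_params, replicated_opt_state, sharded_X, sharded_y)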