Data parallelism is one of the most common and effective strategies for speeding up machine learning workloads, particularly model training. The core idea is straightforward: you replicate your model on multiple compute devices (like GPUs or TPUs) and feed each replica a different slice, or shard, of the input data batch. Each device processes its data shard independently using the same model parameters. pmap is JAX's primary tool for implementing this SPMD (Single-Program Multiple-Data) pattern efficiently.
Recall that pmap maps a function over arrays whose leading axis corresponds to the number of devices involved. When applied to data parallelism, this means:

- Input arrays mapped over their leading axis are split, so each device receives one shard of the global batch.
- The function you wrap with pmap (often your model's forward pass or the entire training step) is implicitly replicated across all participating devices.
- If the function arguments representing model parameters do not have a mapped leading axis, JAX automatically broadcasts them, ensuring each device gets the same copy.

Let's illustrate this with a conceptual example. Suppose we have a simple prediction function and want to run it in parallel across available devices using data parallelism.
First, we need some setup: identify devices and define a function.
import jax
import jax.numpy as jnp
# Get the number of available devices
num_devices = jax.local_device_count()
print(f"Number of devices: {num_devices}")
# Example function (e.g., a simplified model layer)
def predict(params, inputs):
    # A simple linear transformation
    return jnp.dot(inputs, params['w']) + params['b']
# Generate dummy parameters (weights and bias)
# These will be replicated across devices
key = jax.random.PRNGKey(0)
w_key, b_key, data_key = jax.random.split(key, 3)
input_dim = 10
output_dim = 5
params = {
    'w': jax.random.normal(w_key, (input_dim, output_dim)),
    'b': jax.random.normal(b_key, (output_dim,))
}
# Generate a global batch of data
global_batch_size = 32 * num_devices  # Example total batch size
dummy_data = jax.random.normal(data_key, (global_batch_size, input_dim))
print(f"Parameter shapes: w={params['w'].shape}, b={params['b'].shape}")
print(f"Global data batch shape: {dummy_data.shape}")
Now, the critical step for data parallelism is preparing the input data. pmap expects the input data array to have a leading dimension equal to the number of devices (num_devices). We need to reshape our dummy_data accordingly.
# Reshape data for pmap: [num_devices, batch_per_device, features]
batch_per_device = global_batch_size // num_devices
sharded_data = dummy_data.reshape((num_devices, batch_per_device, input_dim))
print(f"Sharded data shape: {sharded_data.shape}")
# Expected shape: (num_devices, batch_per_device, input_dim)
With the data correctly sharded, we can now apply pmap. Notice that params is passed directly, without a device-sized leading dimension; by marking it with in_axes=None (shown below), we tell pmap to broadcast it, copying the same parameters to each device. sharded_data, however, does have the device-sized leading dimension, and its in_axes entry of 0 tells pmap to split it along that axis.
# Apply pmap to the predict function
# params are broadcast, sharded_data is split along the first axis
parallel_predict = jax.pmap(predict, in_axes=(None, 0))
# Run the parallel computation
# We don't need to explicitly replicate params, pmap handles broadcasting
sharded_predictions = parallel_predict(params, sharded_data)
# Ensure computation completes before checking shape
sharded_predictions.block_until_ready()
print(f"Output predictions shape: {sharded_predictions.shape}")
# Expected shape: (num_devices, batch_per_device, output_dim)
The in_axes argument specifies how pmap should treat each input argument:

- None: Broadcast this argument. The same value is sent to all devices. This is typical for model parameters.
- 0: Map over the first axis (axis 0) of this argument. The array is split along axis 0, and each slice is sent to a different device. This is standard for input data in data parallelism.

The output sharded_predictions also has a leading dimension corresponding to the number of devices. Each slice sharded_predictions[i] contains the result computed on device i using its portion of the input data (sharded_data[i]) and the replicated params.
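As a quick sanity check (using the objects defined above), each device's slice should match running predict directly on the corresponding data shard, and the stacked outputs can be reshaped back into a single global batch.
# Verify one device's output against an ordinary (non-parallel) call
single_device_result = predict(params, sharded_data[0])
print(jnp.allclose(sharded_predictions[0], single_device_result))  # Expect: True
# Merge the per-device outputs back into a single global batch
merged_predictions = sharded_predictions.reshape(-1, output_dim)
print(f"Merged predictions shape: {merged_predictions.shape}")
# Expected shape: (global_batch_size, output_dim)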
Diagram: data flow in pmap for data parallelism. The global batch is split (sharded) across devices, model parameters are replicated (broadcast) to each device, each device computes its result independently, and the outputs are stacked along the device axis.
This example demonstrates the core mechanic of applying a function in parallel to sharded data using pmap. In a typical training scenario, the predict function would be part of a larger train_step function that also calculates the loss and gradients. While the forward pass and loss calculation per device are independent, the gradients computed on each device usually need to be combined (e.g., averaged) across all devices before updating the model parameters. This crucial aggregation step requires collective communication primitives, which we will explore in the next section.