Data parallelism is one of the most common and effective strategies for speeding up machine learning workloads, particularly model training. The core idea is straightforward: you replicate your model on multiple compute devices (like GPUs or TPUs) and feed each replica a different slice, or shard, of the input data batch. Each device processes its data shard independently using the same model parameters. pmap is JAX's primary tool for implementing this SPMD (Single-Program Multiple-Data) pattern efficiently.
Recall that pmap maps a function over arrays whose leading axis corresponds to the number of devices involved. When applied to data parallelism, this means:
The function passed to pmap (often your model's forward pass or the entire training step) is implicitly replicated across all participating devices. Arguments that represent model parameters are typically not mapped over a device axis; marking them as unmapped tells JAX to broadcast them, so every device works with the same copy, while the data argument is split along its leading axis.
Let's illustrate this with an example. Suppose we have a simple prediction function and want to run it in parallel across available devices using data parallelism.
First, we need some setup: identify devices and define a function.
import jax
import jax.numpy as jnp
# Get the number of available devices
num_devices = jax.local_device_count()
print(f"Number of devices: {num_devices}")
# Example function (e.g., a simplified model layer)
def predict(params, inputs):
    # A simple linear transformation
    return jnp.dot(inputs, params['w']) + params['b']
# Generate dummy parameters (weights and bias)
# These will be replicated across devices
key = jax.random.PRNGKey(0)
# Split the key so weights, bias, and data use independent random streams
key_w, key_b, key_data = jax.random.split(key, 3)
input_dim = 10
output_dim = 5
params = {
    'w': jax.random.normal(key_w, (input_dim, output_dim)),
    'b': jax.random.normal(key_b, (output_dim,))
}
# Generate a global batch of data
global_batch_size = 32 * num_devices # Example total batch size
dummy_data = jax.random.normal(key_data, (global_batch_size, input_dim))
print(f"Parameter shapes: w={params['w'].shape}, b={params['b'].shape}")
print(f"Global data batch shape: {dummy_data.shape}")
Now, the critical step for data parallelism is preparing the input data. pmap expects the input data array to have a leading dimension equal to the number of devices (num_devices). We need to reshape our dummy_data accordingly.
# Reshape data for pmap: [num_devices, batch_per_device, features]
batch_per_device = global_batch_size // num_devices
sharded_data = dummy_data.reshape((num_devices, batch_per_device, input_dim))
print(f"Sharded data shape: {sharded_data.shape}")
# Expected shape: (num_devices, batch_per_device, input_dim)
With the data correctly sharded, we can now apply pmap. Notice that params is passed directly, without a device dimension. Because its position in in_axes is marked as None, pmap broadcasts (copies) it to every device. sharded_data, mapped over axis 0, is split so that each device receives one slice.
# Apply pmap to the predict function
# params are broadcast, sharded_data is split along the first axis
parallel_predict = jax.pmap(predict, in_axes=(None, 0))
# Run the parallel computation
# We don't need to explicitly replicate params, pmap handles broadcasting
sharded_predictions = parallel_predict(params, sharded_data)
# Ensure computation completes before checking shape
sharded_predictions.block_until_ready()
print(f"Output predictions shape: {sharded_predictions.shape}")
# Expected shape: (num_devices, batch_per_device, output_dim)
The in_axes argument specifies how pmap should treat each input argument:
None: Broadcast this argument. The same value is sent to all devices. This is typical for model parameters.
0: Map over the first axis (axis 0) of this argument. The array is split along axis 0, and each slice is sent to a different device. This is standard for input data in data parallelism.
The output sharded_predictions also has a leading dimension corresponding to the number of devices. Each slice sharded_predictions[i] contains the result computed on device i from its portion of the input data (sharded_data[i]) and the replicated params.
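To confirm these semantics, a small sanity check (a sketch reusing the arrays defined above) compares the reshaped pmap output with the same computation run on a single device:
# The stacked per-device outputs, flattened back to the global batch,
# should match an ordinary single-device computation.
reference = predict(params, dummy_data)  # shape: (global_batch_size, output_dim)
flattened = sharded_predictions.reshape(global_batch_size, output_dim)
print(jnp.allclose(reference, flattened, atol=1e-5))  # expect: True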
Data flow in pmap for data parallelism. The global batch is split (sharded) across devices. Model parameters are typically replicated (broadcast) to each device. Each device computes its result independently. Outputs are stacked along the device axis.
This example demonstrates the core mechanic of applying a function in parallel to sharded data using pmap. In a typical training scenario, the predict function would be part of a larger train_step function that also calculates the loss and gradients. While the forward pass and loss calculation per device are independent, the gradients computed on each device usually need to be combined (e.g., averaged) across all devices before updating the model parameters. This important aggregation step requires collective communication primitives, which we will explore in the next section.
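As a bridge to that discussion, here is a hedged sketch of how the pieces might fit together; loss_fn, the dummy targets, and the per_device_grads helper are illustrative assumptions rather than a fixed recipe, and the collective averaging step is left as a comment because it is the subject of the next section:
# Hypothetical mean-squared-error loss built on the predict function above.
def loss_fn(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets) ** 2)

def per_device_grads(params, inputs, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    # `grads` reflect only this device's shard of the batch. In a full
    # train_step they would be averaged across devices (e.g. with
    # jax.lax.pmean inside the pmapped function) before the parameter
    # update, using the collective operations covered next.
    return loss, grads

# Dummy targets, sharded the same way as the inputs.
dummy_targets = jax.random.normal(jax.random.PRNGKey(1), (global_batch_size, output_dim))
sharded_targets = dummy_targets.reshape((num_devices, batch_per_device, output_dim))

parallel_grads = jax.pmap(per_device_grads, in_axes=(None, 0, 0))
per_device_loss, per_device_grad_tree = parallel_grads(params, sharded_data, sharded_targets)
print(per_device_loss.shape)            # (num_devices,)
print(per_device_grad_tree['w'].shape)  # (num_devices, input_dim, output_dim)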