The Single-Program Multiple-Data (SPMD) model is a widely used paradigm in parallel computing, especially suitable for accelerators like GPUs and TPUs. The core idea is straightforward: you write a single program, and this exact program executes simultaneously across multiple processors or devices. However, each instance of the program operates on a different subset of the overall data. This contrasts with other models like Multiple-Program Multiple-Data (MPMD), where different programs might run on different processors. For many machine learning tasks, particularly data parallelism, SPMD is a natural fit.
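As a toy illustration of the idea, here is a purely conceptual sketch (the function name program and the shard values are made up for illustration; the loop below runs sequentially, whereas real SPMD execution runs each call at the same time on its own device):

# One program, applied to a different shard of the data per "device".
# Written as a sequential loop only to show the idea; under SPMD,
# every iteration would execute simultaneously on a separate device.
def program(data_shard):
    return [value * 2 for value in data_shard]

data_shards = [[1, 2], [3, 4], [5, 6], [7, 8]]        # one shard per device
results = [program(shard) for shard in data_shards]   # same program, different data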
In JAX, the primary tool for implementing SPMD-style parallelism across multiple devices is jax.pmap (parallel map). Conceptually, pmap transforms a Python function written for a single device into one that executes in parallel across multiple devices (like the GPUs or TPU cores available to your JAX process). It handles the replication of the computation and the distribution (sharding) of data automatically.
Think of pmap as being analogous to Python's built-in map function, but instead of mapping a function over elements of a list in sequence, pmap maps a function over devices in parallel. Each device executes the same compiled function but receives a unique slice of the input data.
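To make the analogy concrete, here is a minimal sketch contrasting the two (the function double and the array shapes are illustrative only):

import jax
import jax.numpy as jnp

def double(x):
    return x * 2

# Sequential: map applies the function to list elements one at a time.
sequential = list(map(double, [jnp.ones(3), jnp.ones(3)]))

# Parallel: pmap applies the function to slices along the leading axis,
# one slice per device. The leading axis size must not exceed
# jax.local_device_count().
n = jax.local_device_count()
parallel = jax.pmap(double)(jnp.ones((n, 3)))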
pmap Enables SPMD
Let's illustrate the SPMD concept with pmap. Imagine you have 4 TPU cores and a batch of data you want to process.
1. Write the per-device function: express the computation as if it operated on a single slice of the data.
2. Apply pmap: You apply jax.pmap to this function.
3. Shard the data: For a global batch of shape (128, 50) (batch size 128, feature size 50) and 4 devices, you would typically reshape it or ensure your data loading provides it as (4, 32, 50). The leading dimension (size 4) represents the device axis (see the reshape sketch below).
4. Call the transformed function: When you call the pmap-transformed function with this sharded data, JAX does the following: it compiles the function once with XLA (much like jit), replicates the compiled computation on each participating device, and sends each device its slice of the input along the leading axis.
The following diagram provides a conceptual view:
Each device executes the same compiled code but operates on its assigned slice of the input data. Outputs are gathered back and stacked along a new leading device axis.
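Before the full example, here is a minimal sketch of just the sharding step from the list above, using the illustrative (128, 50) batch and 4 devices (the names global_batch and sharded_batch, and the zero-filled data, are placeholders for this sketch):

import jax.numpy as jnp

num_devices = 4
global_batch = jnp.zeros((128, 50))        # (batch size, feature size)
per_device_batch = 128 // num_devices      # 32 examples per device
sharded_batch = global_batch.reshape(
    (num_devices, per_device_batch, 50)    # leading axis = device axis
)
print(sharded_batch.shape)                 # (4, 32, 50)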
Let's see a concrete example. We'll define a simple function and apply it across multiple devices using pmap. First, ensure JAX can see your available devices.
import jax
import jax.numpy as jnp

# Check available devices (CPUs, GPUs, or TPU cores)
num_devices = jax.local_device_count()
print(f"Number of available devices: {num_devices}")

# Example: Use 4 devices if available, otherwise use the actual count
if num_devices >= 4:
    num_devices_to_use = 4
else:
    num_devices_to_use = num_devices
print(f"Using {num_devices_to_use} devices for pmap.")

# Create some example data, sharded across the device dimension
# Total batch size = num_devices * per_device_batch_size
per_device_batch_size = 8
feature_size = 16
global_batch_size = num_devices_to_use * per_device_batch_size

# Shape: (num_devices, per_device_batch_size, feature_size)
sharded_data = jnp.arange(global_batch_size * feature_size).reshape(
    (num_devices_to_use, per_device_batch_size, feature_size)
)
print(f"Shape of sharded input data: {sharded_data.shape}")

# Define a simple function to apply per device
def simple_computation(x):
    # Example: Scale and add a constant
    return x * 2.0 + 1.0

# Apply pmap to the function
# By default, pmap assumes the first axis (axis 0) of inputs
# should be mapped across devices.
parallel_computation = jax.pmap(simple_computation)

# Execute the parallel computation
# JAX distributes the leading axis of sharded_data across devices
result = parallel_computation(sharded_data)

# The output is also sharded across the leading axis
print(f"Shape of the output: {result.shape}")

# Verify a value from one device's output (e.g., first element on device 0)
# Original value was 0. Computation is 0 * 2.0 + 1.0 = 1.0
print(f"Result[0, 0, 0]: {result[0, 0, 0]}")
In this example:
- We created sharded_data, where the first dimension matches the number of devices we intend to use. Each slice sharded_data[i] will be sent to device i.
- We defined simple_computation, which operates on a single slice of data.
- jax.pmap(simple_computation) creates parallel_computation, a new function ready for SPMD execution.
- Calling parallel_computation(sharded_data) triggers the parallel execution. Each device runs simple_computation on its corresponding data slice sharded_data[i].
- The output result has the same shape as the input, with the leading axis representing the devices. result[i] contains the output computed by device i (a quick sanity check follows below).
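As a quick sanity check (assuming the example above has just been run, so simple_computation, sharded_data, and result are still in scope), the pmap output should match what the same function produces without pmap, since the computation is element-wise:

# Applying the function to the full sharded array on a single device
# produces the same values that pmap computed device by device.
expected = simple_computation(sharded_data)
print(bool(jnp.allclose(result, expected)))   # True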
This demonstrates the essence of SPMD with pmap: define the per-device logic, let pmap handle the replication and parallel execution across devices by mapping over the leading axis of the input arrays. The underlying XLA compilation ensures the core computation is optimized for the target hardware. In the following sections, we will explore how to handle replicated data (like model parameters) using in_axes and how to perform communication between devices using collective operations.