The Single-Program Multiple-Data (SPMD) model is a widely used approach in parallel computing, especially well suited to accelerators like GPUs and TPUs. The core idea is straightforward: you write a single program, and that exact program executes simultaneously across multiple processors or devices, with each instance operating on a different subset of the overall data. This contrasts with models like Multiple-Program Multiple-Data (MPMD), where different programs may run on different processors. For many machine learning tasks, particularly data parallelism, SPMD is a natural fit.

In JAX, the primary tool for implementing SPMD-style parallelism across multiple devices is `jax.pmap` (parallel map). `pmap` transforms a Python function written for a single device into one that executes in parallel across multiple devices (such as the GPUs or TPU cores available to your JAX process). It handles the replication of the computation and the distribution (sharding) of data automatically.

Think of `pmap` as analogous to Python's built-in `map` function: instead of mapping a function over the elements of a list in sequence, `pmap` maps a function over devices in parallel. Each device executes the same compiled function but receives a unique slice of the input data.

### How pmap Enables SPMD

Let's illustrate the SPMD concept with `pmap`. Imagine you have 4 TPU cores and a batch of data you want to process.

**Write single-device code:** First, write your JAX function as if it were running on a single device. This function might define a layer of a neural network, a loss calculation, or any other computation.

**Transform with `pmap`:** Apply `jax.pmap` to this function.

**Data sharding:** Prepare your input data so that its leading axis corresponds to the number of devices. For instance, if you have a data array of shape `(128, 50)` (batch size 128, feature size 50) and 4 devices, you would typically reshape it, or ensure your data loading provides it, as `(4, 32, 50)`, as in the sketch below.
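The following is a minimal sketch of that reshape; the array contents here are placeholders:

```python
import jax.numpy as jnp

num_devices = 4                   # assumes 4 devices are available
data = jnp.zeros((128, 50))       # global batch: (batch_size, feature_size)

# Split the global batch into one leading-axis slice per device.
sharded = data.reshape((num_devices, 128 // num_devices, 50))
print(sharded.shape)              # (4, 32, 50)
```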
The leading dimension (size 4) represents the device axis.

**Execution:** When you call the `pmap`-transformed function with this sharded data, JAX does the following:

1. It compiles the original function using XLA (if it was not already compiled via `jit`).
2. It replicates the compiled XLA computation onto each of the 4 devices.
3. It sends the corresponding slice of the input data to each device (device 0 gets data slice 0, device 1 gets data slice 1, and so on).
4. All devices execute the computation simultaneously on their local data slices.
5. The results from each device are gathered and stacked along a new leading axis in the output.

The following diagram illustrates this flow:

```dot
digraph G {
    rankdir=LR;
    splines=false;
    node [shape=box, style="rounded,filled", fontname="sans-serif", fillcolor="#e9ecef"];

    subgraph cluster_pmap {
        label = "pmap'd Function Execution";
        bgcolor="#f8f9fa";
        style=filled;
        node [shape=Mrecord, fillcolor="#a5d8ff", style=filled];
        Device0 [label="{Device 0 | Input Slice 0 | {Execute\nCompiled Code} | Output Slice 0}"];
        Device1 [label="{Device 1 | Input Slice 1 | {Execute\nCompiled Code} | Output Slice 1}"];
        DeviceN [label="{Device N | Input Slice N | {Execute\nCompiled Code} | Output Slice N}"];
    }

    InputData [label="Sharded Input Data\n(N, ...)", shape=folder, fillcolor="#ffd8a8", style=filled];
    OutputData [label="Stacked Output Data\n(N, ...)", shape=folder, fillcolor="#b2f2bb", style=filled];
    CompiledCode [label="Single Compiled\nFunction (XLA)", shape=note, fillcolor="#eebefa", style=filled];

    InputData -> Device0 [label="Shard 0"];
    InputData -> Device1 [label="Shard 1"];
    InputData -> DeviceN [label="Shard N"];

    CompiledCode -> Device0 [style=dashed, color="#ae3ec9"];
    CompiledCode -> Device1 [style=dashed, color="#ae3ec9"];
    CompiledCode -> DeviceN [style=dashed, color="#ae3ec9"];

    Device0 -> OutputData [label="Slice 0"];
    Device1 -> OutputData [label="Slice 1"];
    DeviceN -> OutputData [label="Slice N"];

    {rank=same; Device0 Device1 DeviceN}
}
```

Each device executes the same compiled code but operates on its assigned slice of the input data. Outputs are typically gathered back and stacked along a new device axis.
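Before turning to real `pmap` code, it can help to see the net effect of these steps in plain Python. The sketch below is a conceptual model only, not JAX's implementation: it runs sequentially and ignores compilation, but it produces the same output shapes as a `pmap`-ed call:

```python
import jax.numpy as jnp

def pmap_semantics(f, sharded_x):
    # Apply f to each leading-axis slice (one "device" per slice),
    # then stack the per-slice results along a new leading axis.
    return jnp.stack([f(sharded_x[i]) for i in range(sharded_x.shape[0])])

# e.g. pmap_semantics(lambda x: x * 2.0, jnp.ones((4, 8))) has shape (4, 8)
```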
### Basic Usage Example

Let's look at a concrete example. We'll define a simple function and apply it across multiple devices with `pmap`. First, check which devices JAX can see.

```python
import jax
import jax.numpy as jnp

# Check available devices (CPUs, GPUs, or TPU cores)
num_devices = jax.local_device_count()
print(f"Number of available devices: {num_devices}")

# Example: use 4 devices if available, otherwise use the actual count
if num_devices >= 4:
    num_devices_to_use = 4
else:
    num_devices_to_use = num_devices
print(f"Using {num_devices_to_use} devices for pmap.")

# Create some example data, sharded across the device dimension.
# Total batch size = num_devices * per_device_batch_size
per_device_batch_size = 8
feature_size = 16
global_batch_size = num_devices_to_use * per_device_batch_size

# Shape: (num_devices, per_device_batch_size, feature_size)
sharded_data = jnp.arange(global_batch_size * feature_size).reshape(
    (num_devices_to_use, per_device_batch_size, feature_size)
)
print(f"Shape of sharded input data: {sharded_data.shape}")

# Define a simple function to apply per device
def simple_computation(x):
    # Example: scale and add a constant
    return x * 2.0 + 1.0

# Apply pmap to the function. By default, pmap maps the first axis
# (axis 0) of each input across devices.
parallel_computation = jax.pmap(simple_computation)

# Execute the parallel computation. JAX distributes the leading
# axis of sharded_data across devices.
result = parallel_computation(sharded_data)

# The output is also sharded across the leading axis
print(f"Shape of the output: {result.shape}")

# Verify a value from one device's output (e.g., first element on device 0).
# The original value was 0, and 0 * 2.0 + 1.0 = 1.0.
print(f"Result[0, 0, 0]: {result[0, 0, 0]}")
```

In this example:

- We determine the number of devices JAX can access.
- We create input data `sharded_data` whose first dimension matches the number of devices we intend to use. Each slice `sharded_data[i]` will be sent to device `i`.
- We define `simple_computation`, which operates on a single slice of data.
- `jax.pmap(simple_computation)` creates `parallel_computation`, a new function ready for SPMD execution.
- Calling `parallel_computation(sharded_data)` triggers the parallel execution: each device runs `simple_computation` on its corresponding data slice `sharded_data[i]`.
- The output `result` has the same shape as the input, with the leading axis representing the devices; `result[i]` contains the output computed by device `i`.

This demonstrates the essence of SPMD with `pmap`: define the per-device logic, and let `pmap` handle replication and parallel execution by mapping over the leading axis of the input arrays. The underlying XLA compilation ensures the core computation is optimized for the target hardware. In the following sections, we will explore how to handle replicated data (like model parameters) using `in_axes` and how to perform communication between devices using collective operations.
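As a small preview of `in_axes`, which the next section covers in detail, the sketch below (an illustrative example, not the full treatment) replicates a scalar parameter on every device while splitting the data along its leading axis:

```python
import jax
import jax.numpy as jnp

def scale(w, x):
    return x * w

n = jax.local_device_count()
w = jnp.float32(2.0)        # in_axes=None: replicated on every device
x = jnp.ones((n, 8))        # in_axes=0: split along the leading axis
y = jax.pmap(scale, in_axes=(None, 0))(w, x)
print(y.shape)              # (n, 8)
```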