As we discussed in the chapter introduction, while jax.jit compiles your code for efficient execution on a single accelerator and jax.vmap provides automatic vectorization, distributing computations across multiple devices like GPUs or TPUs requires a different tool: jax.pmap.
pmap stands for "parallel map". It's a function transformation designed specifically for the Single Program, Multiple Data (SPMD) paradigm. In the SPMD model, you write your program (your function) once, and pmap arranges for it to be executed simultaneously on multiple devices. Each device receives and processes a different portion of the input data. This is a common and effective way to achieve data parallelism, speeding up computations by dividing the workload.
At its core, using pmap is syntactically similar to using jit or vmap. You can apply it as a decorator or call it directly on your function. Let's start with a simple example.
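Both forms produce the same transformed function. Here is a minimal sketch (the names double and double_direct are illustrative and not part of the example that follows):

import jax

# Used as a decorator...
@jax.pmap
def double(x):
    return x * 2

# ...or applied directly to an existing function
def double_plain(x):
    return x * 2

double_direct = jax.pmap(double_plain)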
First, let's see what devices JAX can detect. This is important because pmap needs multiple devices to distribute work across.
import jax
import jax.numpy as jnp
# Check available devices
print(f"Available JAX devices: {jax.devices()}")
num_devices = len(jax.devices())
print(f"Number of devices: {num_devices}")
If you are running this on a machine with multiple GPUs or on a TPU pod, jax.devices() will list them. If you only have a CPU or a single GPU, pmap will still work, but it will run all computations on that single device, essentially mimicking vmap (though less efficiently for this purpose). The real benefits appear when num_devices > 1.
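If you only have a CPU available, one common workaround for experimenting with pmap is to ask XLA to expose several virtual CPU devices. The sketch below assumes you can set the environment variable before JAX is first imported (for example, at the very top of your script):

# Sketch: simulate multiple devices on a CPU-only machine.
# XLA_FLAGS must be set before jax is imported for the first time.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # should now report 8 CPU devices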
Now, let's define a simple function and apply pmap:
# A simple function to apply pmap to
def simple_computation(x):
    return x * 2

# Apply pmap
pmapped_computation = jax.pmap(simple_computation)

# Prepare input data: needs a leading axis matching the number of devices
try:
    # Create data to be sharded across devices.
    # Example: if num_devices is 4, create an array of shape (4, ...)
    input_data = jnp.arange(num_devices * 3).reshape((num_devices, 3))
    print(f"Input data shape: {input_data.shape}")

    # Execute the pmapped function
    result = pmapped_computation(input_data)

    print(f"Output result:\n{result}")
    print(f"Output type: {type(result)}")
    print(f"Output shape: {result.shape}")
    print(f"Output devices: {result.devices()}")
except Exception as e:
    print(f"Error during pmap execution: {e}")
    print("Note: pmap typically requires the size of the mapped axis")
    print("to be equal to the number of available devices.")
    print("If running with only 1 device, this example might not showcase parallelism.")
You'll notice a few important things here:
- input_data needed a leading dimension whose size equals num_devices. Here, we created an array of shape (num_devices, 3). pmap automatically splits this array along the first axis (axis 0 by default) and sends each slice (of shape (3,) in this case) to a different device.
- Each device then executes simple_computation on its own slice of the data concurrently.
- The returned result is typically a ShardedDeviceArray (or similar distributed array type). This indicates that the result data also physically resides distributed across the devices. Its shape mirrors the input shape (num_devices, 3), and you can confirm its distribution using result.devices().
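To convince yourself of this, you can compare the distributed result with the same computation run on a single device, and copy the shards back to the host. This quick check assumes the example above ran successfully:

# The values match what you'd get without pmap; only the placement differs.
expected = simple_computation(input_data)              # single-device computation
print(jnp.array_equal(jnp.asarray(result), expected))  # True

# device_get copies the sharded result back into a single host (NumPy) array
gathered = jax.device_get(result)
print(type(gathered), gathered.shape)                  # (num_devices, 3)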
How pmap Works (Conceptually)

Think of pmap as doing the following:
1. Compilation: Like jit, pmap first compiles your Python function (using XLA) into an optimized executable program suitable for the target hardware (GPU/TPU).
2. Data splitting: When you call the pmapped function, JAX takes the input array(s) and splits them along a specified axis (defaulting to axis 0). Each chunk is sent to one of the devices.
3. Parallel execution: Each device runs the compiled program on its own chunk of data at the same time.
4. Result gathering: The per-device outputs are collected and returned as a distributed array, typically a ShardedDeviceArray.

Here's a simplified visual representation:
A conceptual diagram showing how pmap splits input data along the first axis, sends each slice to a different device for parallel execution of the same program, and gathers the results into a distributed array.
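To make these steps concrete, here is a rough, manual approximation of the same pipeline using jit, device_put, and an explicit Python loop. pmap does all of this for you in a single step and runs the devices in parallel, so treat this purely as an illustration:

# Manual sketch of the pmap steps: compile once, split, run per device, gather.
compiled = jax.jit(simple_computation)                 # 1. compile the program once
slices = [input_data[i] for i in range(num_devices)]   # 2. split along axis 0
outputs = [
    compiled(jax.device_put(s, d))                     # 3. same program on each device
    for s, d in zip(slices, jax.devices())             #    (sequential here, parallel in pmap)
]
manual_result = jnp.stack(outputs)                     # 4. gather into one array
print(manual_result)                                   # matches pmapped_computation(input_data)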
pmap vs vmap

It's useful to contrast pmap with vmap:
- vmap (Vectorization): Transforms a function that works on single examples into one that works on batches of examples, conceptually within a single computation graph. It achieves parallelism through vectorization instructions on a single device. It doesn't inherently distribute work across multiple physical devices.
- pmap (Parallelization): Explicitly replicates the computation across multiple physical devices. It executes the same function on different slices of data in parallel on different accelerators.

While both can process batches of data, pmap is the tool for scaling your computation beyond a single device by leveraging multiple available accelerators.
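A small side-by-side comparison, reusing simple_computation and input_data from above (and assuming the leading axis still equals the number of devices):

batched = jax.vmap(simple_computation)(input_data)   # vectorized on one device
parallel = jax.pmap(simple_computation)(input_data)  # one slice per device

print(jnp.array_equal(batched, jnp.asarray(parallel)))  # same values: True
print(len(batched.devices()), len(parallel.devices()))  # 1 device vs num_devices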
In the following sections, we will examine how to control which axes are mapped, how to handle multiple arguments, and how to perform communication between devices using collective operations within pmapped functions.
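As a brief preview of those collective operations, the sketch below pairs an axis name with jax.lax.psum so that every device can see a sum computed across all devices. The axis name "devices" is an arbitrary label chosen for this example:

from functools import partial

@partial(jax.pmap, axis_name="devices")
def fraction_of_total(x):
    # psum adds x across every device participating in the "devices" axis,
    # so each device ends up holding its share of the global total.
    total = jax.lax.psum(x, axis_name="devices")
    return x / total

values = jnp.arange(1.0, num_devices + 1.0)   # one scalar per device
print(fraction_of_total(values))              # per-device fractions that sum to 1.0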