jax.pmap
While jax.jit compiles code for efficient execution on a single accelerator and jax.vmap provides automatic vectorization, distributing computations across multiple devices like GPUs or TPUs requires jax.pmap.
pmap stands for "parallel map". It's a function transformation designed specifically for the Single Program, Multiple Data (SPMD) model. In this model, you write your program (your function) once, and pmap arranges for it to be executed simultaneously on multiple devices. Each device receives and processes a different portion of the input data. This is a common and effective way to achieve data parallelism, speeding up computations by dividing the workload.
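As a rough sketch of what "a different portion of the input data" can look like in practice, the snippet below reshapes a batch so that its leading axis indexes devices and lets each device reduce its own shard. The names per_device_batch and per_device_sum are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()
per_device_batch = 4  # hypothetical number of examples handled by each device

# A global batch of (n_devices * per_device_batch) examples with 2 features each
global_batch = jnp.arange(
    n_devices * per_device_batch * 2, dtype=jnp.float32
).reshape((n_devices * per_device_batch, 2))

# Give the batch a leading "device" axis: (n_devices, per_device_batch, 2)
sharded_batch = global_batch.reshape((n_devices, per_device_batch, 2))

def per_device_sum(shard):
    # Runs once per device on a shard of shape (per_device_batch, 2)
    return jnp.sum(shard, axis=0)

partial_sums = jax.pmap(per_device_sum)(sharded_batch)
print(f"Partial sums shape: {partial_sums.shape}")  # (n_devices, 2)
```

Each row of partial_sums is the contribution of one device. Adding the rows together on the host recovers the sum over the whole batch.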
At its core, using pmap is syntactically similar to using jit or vmap. You can apply it as a decorator or call it directly on your function. Let's start with a simple example.
First, let's see what devices JAX can detect. This is important because pmap needs multiple devices to distribute work across.
import jax
import jax.numpy as jnp
# Check available devices
print(f"Available JAX devices: {jax.devices()}")
num_devices = len(jax.devices())
print(f"Number of devices: {num_devices}")
If you are running this on a machine with multiple GPUs or on a TPU pod, jax.devices() will list them. If you only have a CPU or a single GPU, pmap will still work, but the mapped axis can then only have size 1 and all computation runs on that single device, so you gain nothing over jit or vmap. The real benefits appear when num_devices > 1.
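If only a CPU is available, one common way to experiment is to ask XLA to expose several virtual CPU devices through the --xla_force_host_platform_device_count flag. The sketch below assumes the flag is set before JAX initializes its backends (safest is before importing jax, for example at the top of a fresh script or notebook). The count of 8 is an arbitrary choice for illustration.

```python
import os

# Must take effect before JAX initializes its backends, so set it before importing jax.
# Assumption: 8 virtual devices is an arbitrary, illustrative choice.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

cpu_devices = jax.devices("cpu")
print(f"Simulated CPU devices: {len(cpu_devices)}")
```

These virtual devices share the same physical CPU, so they are useful for checking shapes and program logic rather than for measuring speedups.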
Now, let's define a simple function and apply pmap:
# A simple function to apply pmap to
def simple_computation(x):
    return x * 2
# Apply pmap
pmapped_computation = jax.pmap(simple_computation)
# Prepare input data: Needs a leading axis matching the number of devices
try:
    # Create data sharded across devices
    # Example: If num_devices is 4, create an array of shape (4, 3)
    input_data = jnp.arange(num_devices * 3).reshape((num_devices, 3))
    print(f"Input data shape: {input_data.shape}")
    # Execute the pmapped function
    result = pmapped_computation(input_data)
    print(f"Output result:\n{result}")
    print(f"Output type: {type(result)}")
    print(f"Output shape: {result.shape}")
    print(f"Output devices: {result.devices()}")
except Exception as e:
    print(f"Error during pmap execution: {e}")
    print("Note: pmap requires the size of the mapped axis")
    print("to be no larger than the number of available local devices.")
    print("If running with only 1 device, this example will not show parallelism.")
You'll notice a few important things here:
input_data needed a leading dimension whose size equals num_devices (more precisely, pmap requires the mapped axis to be no larger than the number of local devices). Here, we created an array of shape (num_devices, 3). pmap automatically splits this array along the first axis (axis 0 by default) and sends each slice (of shape (3,) in this case) to a different device.
Each device ran simple_computation on its own slice of the data concurrently.
result is a distributed array: a jax.Array in current JAX releases, or a ShardedDeviceArray in older ones. This indicates that the result data also physically resides distributed across the devices. Its shape mirrors the input shape (num_devices, 3), and you can confirm its distribution using result.devices().
How pmap Works
Think of pmap as doing the following:
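1. Like jit, pmap first compiles your Python function (using XLA) into an optimized executable program suitable for the target hardware (GPU/TPU).
2. When you call the pmapped function, JAX takes the input array(s) and splits them along a specified axis (defaulting to axis 0). Each chunk is sent to one of the devices.
3. Each device runs the same compiled program on its own chunk, in parallel with the others.
4. The per-device results are presented as a single distributed array (a jax.Array in current JAX releases, a ShardedDeviceArray in older ones).
Here's a simplified visual representation: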
A diagram showing how pmap splits input data along the first axis, sends each slice to a different device for parallel execution of the same program, and gathers the results into a distributed array.
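The split, execute, and gather steps can be imitated by hand to check your mental model. This is only a conceptual sketch: the manual version runs serially on one device and says nothing about how pmap is implemented internally. The names slices, per_slice, and emulated are made up for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np

def simple_computation(x):
    return x * 2

n = jax.local_device_count()
x = jnp.arange(n * 3).reshape((n, 3))

# Split: one slice of shape (3,) per device along axis 0
slices = [x[i] for i in range(n)]
# Execute: apply the same function to every slice (serially here, in parallel under pmap)
per_slice = [simple_computation(s) for s in slices]
# Gather: stack the per-slice results back along axis 0
emulated = jnp.stack(per_slice)

# The real thing: the same values, but computed on and stored across the devices
result = jax.pmap(simple_computation)(x)
print("Same values:", np.array_equal(np.asarray(result), np.asarray(emulated)))
```

The values agree. What differs is where the work happens and where the output lives.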
pmap vs vmap
It's useful to contrast pmap with vmap:
vmap (Vectorization): Transforms a function that works on single examples into one that works on batches of examples within a single computation graph. It achieves parallelism through vectorized operations on a single device. It doesn't inherently distribute work across multiple physical devices.
pmap (Parallelization): Explicitly replicates the computation across multiple physical devices. It executes the same function on different slices of data in parallel on different accelerators.
While both can process batches of data, pmap is the tool for scaling your computation beyond a single device by using multiple available accelerators.
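A small side-by-side comparison makes the contrast concrete. In this sketch, sum_of_squares is a hypothetical per-example function. Both transformations map it over the leading axis and return the same values, but the outputs are placed differently.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sum_of_squares(x):
    # Written for a single example of shape (4,)
    return jnp.sum(x ** 2)

n = jax.local_device_count()
batch = jnp.arange(n * 4, dtype=jnp.float32).reshape((n, 4))

vmap_out = jax.vmap(sum_of_squares)(batch)  # whole batch vectorized on one device
pmap_out = jax.pmap(sum_of_squares)(batch)  # one row of the batch per device

print("Values match:", np.allclose(np.asarray(vmap_out), np.asarray(pmap_out)))
print("vmap output devices:", vmap_out.devices())
print("pmap output devices:", pmap_out.devices())
```

On a single-device machine the two device sets look the same. With several devices, the pmap output is spread across them while the vmap output stays on one.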
In the following sections, we will examine how to control which axes are mapped, how to handle multiple arguments, and how to perform communication between devices using collective operations within pmapped functions.
© 2026 ApX Machine Learning