For parallelizing computations across multiple devices in JAX, pmap is a primary tool. While pmap enables basic data parallelism by mapping data along a single axis (controlled by in_axes and out_axes), this approach is often insufficient for more complex hardware topologies or advanced parallelization strategies. To manage these scenarios effectively, JAX introduces the concepts of device meshes and named axes.
Think of the available hardware accelerators (CPUs, GPUs, TPUs) not just as a flat list, but as potentially organized into a logical grid or mesh. This mesh provides a way to structure and address the devices. For example, if you have 8 TPUs, you might think of them as:
- A 1D mesh of 8 devices: [Device 0, Device 1, ..., Device 7]
- A 2D mesh of shape (4, 2): [[Device 0, Device 1], [Device 2, Device 3], [Device 4, Device 5], [Device 6, Device 7]]
- A 2D mesh of shape (2, 4): [[Device 0, ..., Device 3], [Device 4, ..., Device 7]]

JAX can report the available devices:
import jax
# See the list of available devices
print(jax.devices())
# Example Output (might differ based on your hardware):
# [CpuDevice(id=0)]
# or for multiple GPUs:
# [GpuDevice(id=0), GpuDevice(id=1), GpuDevice(id=2), GpuDevice(id=3)]
# or for TPUs:
# [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]
When you use pmap, it implicitly operates over a mesh formed by the devices you're running on. For a standard pmap call distributing data along one dimension, it treats the devices as a 1D mesh.
The concept of a mesh becomes more significant when you need to coordinate operations across specific groups of devices or implement more sophisticated parallel patterns like model parallelism combined with data parallelism. While JAX offers explicit mesh management tools (jax.experimental.maps.Mesh), understanding the idea of a logical device arrangement is important even for basic pmap usage, especially when using collective operations.
Figure: Logical arrangements of 4 devices as a 1D or 2D mesh.
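To make the idea of a logical device grid concrete, the sketch below reshapes the available devices into a 2x2 grid and builds a named mesh from it. This is a minimal sketch, assuming at least 4 devices are present; the Mesh class is exposed as jax.sharding.Mesh in recent JAX releases (older versions placed it under jax.experimental.maps).
import numpy as np
import jax
from jax.sharding import Mesh  # jax.experimental.maps.Mesh in older releases

# Arrange the first four devices into a logical 2x2 grid.
devices = np.array(jax.devices()[:4]).reshape((2, 2))

# Name each dimension of the grid; these names can later be referenced
# by sharding rules and collective operations.
mesh = Mesh(devices, axis_names=('data', 'model'))
print(mesh.shape)  # e.g. OrderedDict([('data', 2), ('model', 2)])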
Imagine you have your devices arranged in a mesh. How do you tell JAX which dimension of this mesh corresponds to the data you are splitting with pmap? And how do you coordinate operations across devices lying along a specific dimension of the mesh? This is where axis names come in.
When defining a pmap-transformed function, you can assign a name to the axis of the device mesh over which the parallelism occurs using the axis_name argument.
import jax
import jax.numpy as jnp
# Assume we have 4 devices
def simple_computation(x):
# Some operation...
return x * 2
# Map the computation over devices, naming the device axis 'data_parallel_axis'
# We split the input array 'data' along its first axis (axis 0)
# across the named device axis.
data = jnp.arange(4 * 5).reshape((4, 5)) # Shape (4, 5) -> one row per device
# pmap implicitly creates a 1D mesh of size 4 here.
# 'data_parallel_axis' names this single dimension of the device mesh.
# in_axes=0 means the first axis of 'data' is mapped to this named device axis.
parallel_computation = jax.pmap(simple_computation, axis_name='data_parallel_axis', in_axes=0)
result = parallel_computation(data)
print(result.shape) # Output: (4, 5) - shape is maintained, but computation ran in parallel
print(jax.devices()) # Shows the devices used
# Example result on 4 devices:
# [[ 0 2 4 6 8] <- Computed on device 0
# [10 12 14 16 18] <- Computed on device 1
# [20 22 24 26 28] <- Computed on device 2
# [30 32 34 36 38]] <- Computed on device 3
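pmap's in_axes can also mix mapped and broadcast arguments: a None entry replicates that argument on every device instead of splitting it. The short sketch below reuses data from above; scaled_computation and its scale argument are hypothetical names introduced only for illustration.
def scaled_computation(x, scale):
    # 'x' is split across devices (one row each); 'scale' is replicated.
    return x * scale

parallel_scaled = jax.pmap(scaled_computation,
                           axis_name='data_parallel_axis',
                           in_axes=(0, None))
print(parallel_scaled(data, 3.0).shape)  # (4, 5), same as before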
Why use axis_name?
- Clarity: a descriptive name like 'data_parallel_axis' clearly indicates data parallelism. You might use 'model_parallel_axis' in more complex scenarios.
- Portability: your code refers to the logical axis 'data_parallel_axis', regardless of whether that axis corresponds to 4 GPUs, 8 TPUs, or a different hardware setup.
- Collective operations: operations that communicate across devices inside pmap (like summing results from all devices) need to know which group of devices to communicate with. The axis_name specifies this group. For example, lax.psum(x, axis_name='data_parallel_axis') would sum the value x across all devices participating in the parallel computation identified by 'data_parallel_axis'.

Think of in_axes and out_axes as specifying how data dimensions map onto the device mesh dimensions, while axis_name gives a specific name to a device mesh dimension being used by that pmap instance. This name is then used internally, particularly for coordinating collective communication, which we will cover in the next section.
Even if you only have one dimension of parallelism (the common case for pure data parallelism), naming the axis is good practice and becomes essential as soon as you need devices to communicate during the parallel execution.
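Collective operations are the topic of the next section, but as a small preview, here is a minimal sketch of lax.psum using a named axis. The helper local_then_global is a hypothetical name for illustration; it assumes the code runs under pmap with axis_name='data_parallel_axis' on however many local devices are available.
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def local_then_global(x):
    local_sum = jnp.sum(x)  # partial sum computed on this device
    # Sum the partial results across every device on the named axis.
    return lax.psum(local_sum, axis_name='data_parallel_axis')

batch = jnp.arange(n_devices * 3.0).reshape((n_devices, 3))
totals = jax.pmap(local_then_global, axis_name='data_parallel_axis')(batch)
print(totals)  # every device holds the same global total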