While pmap allows us to run computations in parallel across multiple devices, simply mapping data along one axis (as controlled by in_axes and out_axes) works well for basic data parallelism but falls short when dealing with more complex hardware topologies or parallelization strategies. To manage this complexity, JAX uses the concepts of device meshes and named axes.
Think of the available hardware accelerators (CPUs, GPUs, TPUs) not just as a flat list, but as potentially organized into a logical grid or mesh. This mesh provides a way to structure and address the devices. For example, if you have 8 TPUs, you might think of them as:
a flat 1D list: [Device 0, Device 1, ..., Device 7]
a 4x2 grid: [[Device 0, Device 1], [Device 2, Device 3], [Device 4, Device 5], [Device 6, Device 7]]
a 2x4 grid: [[Device 0, ..., Device 3], [Device 4, ..., Device 7]]
JAX can report the available devices:
import jax
# See the list of available devices
print(jax.devices())
# Example Output (might differ based on your hardware):
# [CpuDevice(id=0)]
# or for multiple GPUs:
# [GpuDevice(id=0), GpuDevice(id=1), GpuDevice(id=2), GpuDevice(id=3)]
# or for TPUs:
# [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]
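To make the grid arrangements above concrete, you can reshape this flat device list yourself with NumPy. This is purely illustrative and assumes 8 devices are available:
import jax
import numpy as np
# Illustrative only: view the same flat device list as different
# logical grids. Assumes 8 devices are available.
devices = np.array(jax.devices())
print(devices.reshape((4, 2)))  # a 4x2 grid
print(devices.reshape((2, 4)))  # a 2x4 grid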
When you use pmap, it implicitly operates over a mesh formed by the devices you're running on. For a standard pmap call distributing data along one dimension, it treats the devices as a 1D mesh.
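For instance, here is a minimal sketch showing that the mapped axis of the input must match the size of this implicit 1D mesh (jax.local_device_count() reports the number of local devices):
import jax
import jax.numpy as jnp
# The leading axis of the input must equal the number of devices,
# i.e. the size of the implicit 1D device mesh.
n = jax.local_device_count()
x = jnp.arange(n * 3).reshape((n, 3))  # one row per device
y = jax.pmap(lambda v: v + 1)(x)
print(y.shape)  # (n, 3)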
The concept of a mesh becomes more significant when you need to coordinate operations across specific groups of devices or implement more sophisticated parallel patterns, such as model parallelism combined with data parallelism. While JAX offers explicit mesh management tools (jax.sharding.Mesh), understanding the idea of a logical device arrangement is important even for basic pmap usage, especially when using collective operations.
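As a brief sketch of explicit mesh construction (not required for the pmap examples below), you can build a named 2D mesh from the device list; this assumes at least 4 devices are available:
import jax
import numpy as np
from jax.sharding import Mesh
# Arrange 4 devices as a 2x2 grid and name the two mesh axes.
devices = np.array(jax.devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, axis_names=('data', 'model'))
print(mesh.shape)  # maps axis names to sizes, e.g. {'data': 2, 'model': 2}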
Logical arrangements of 4 devices as a 1D or 2D mesh.
Imagine you have your devices arranged in a mesh. How do you tell JAX which dimension of this mesh corresponds to the data you are splitting with pmap? And how do you coordinate operations across devices lying along a specific dimension of the mesh? This is where axis names come in.
When defining a pmap-transformed function, you can assign a name to the axis of the device mesh over which the parallelism occurs using the axis_name argument.
import jax
import jax.numpy as jnp
# Assume we have 4 devices
def simple_computation(x):
    # Some operation applied independently on each device...
    return x * 2
# Map the computation over devices, naming the device axis 'data_parallel_axis'
# We split the input array 'data' along its first axis (axis 0)
# across the named device axis.
data = jnp.arange(4 * 5).reshape((4, 5)) # Shape (4, 5) -> one row per device
# pmap implicitly creates a 1D mesh of size 4 here.
# 'data_parallel_axis' names this single dimension of the device mesh.
# in_axes=0 means the first axis of 'data' is mapped to this named device axis.
parallel_computation = jax.pmap(simple_computation, axis_name='data_parallel_axis', in_axes=0)
result = parallel_computation(data)
print(result.shape) # Output: (4, 5) - shape is maintained, but computation ran in parallel
print(jax.devices()) # Shows the devices used
# Example result on 4 devices:
# [[ 0 2 4 6 8] <- Computed on device 0
# [10 12 14 16 18] <- Computed on device 1
# [20 22 24 26 28] <- Computed on device 2
# [30 32 34 36 38]] <- Computed on device 3
Why use axis_name?
Readability: a name like 'data_parallel_axis' clearly indicates data parallelism. You might use 'model_parallel_axis' in more complex scenarios.
Abstraction: your code refers to the logical axis 'data_parallel_axis', regardless of whether that axis corresponds to 4 GPUs, 8 TPUs, or a different hardware setup.
Collective operations: collective operations within pmap (like summing results from all devices) need to know which group of devices to communicate with. The axis_name specifies this group. For example, lax.psum(x, axis_name='data_parallel_axis') would sum the value x across all devices participating in the parallel computation identified by 'data_parallel_axis', as sketched below.
Think of in_axes and out_axes as specifying how data dimensions map onto the device mesh dimensions, while axis_name gives a specific name to a device mesh dimension being used by that pmap instance. This name is then used internally, particularly for coordinating collective communication, which we cover in the next section.
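As a small preview of that coordination, here is a minimal sketch using lax.psum with the same axis name; the per-device values are chosen positive so the division is well defined:
import jax
import jax.numpy as jnp
from jax import lax
n = jax.local_device_count()
values = jnp.arange(1.0, n + 1.0)  # one positive scalar per device
def normalize(x):
    # psum sums x across every device along the named axis;
    # each device then uses the global total locally.
    total = lax.psum(x, axis_name='data_parallel_axis')
    return x / total
print(jax.pmap(normalize, axis_name='data_parallel_axis')(values))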
Even if you only have one dimension of parallelism (the common case for pure data parallelism), naming the axis is good practice and becomes essential as soon as you need devices to communicate during the parallel execution.