As introduced earlier, jax.pmap operates using a Single Program, Multiple Data (SPMD) model. This means the same Python function code is executed on multiple devices (like GPUs or TPU cores), but each device typically works on a different portion of the input data. The crucial question then becomes: how do we tell JAX which data goes to which device, and how are the results from each device combined? This is where the in_axes and out_axes arguments of pmap come into play.
Think of in_axes and out_axes as instructions for JAX on how to handle the "multiple data" part of SPMD. They specify the axes along which your input arrays are mapped across devices and along which the per-device outputs are stacked back together.
in_axes
The in_axes argument tells pmap how to distribute the input arguments of your function across the available devices. It specifies, for each argument, which axis should be mapped, i.e., which axis provides one slice per device.

in_axes is typically an integer, None, or a (potentially nested) structure (like a tuple, list, or dictionary) matching the structure of the function's arguments (a PyTree).

- Integer Value (e.g., 0): If in_axes for a particular argument is 0, the first axis (axis 0) of that NumPy or JAX array is mapped across the devices: each device receives one slice along that axis, and the mapped axis is removed inside the function. Because one slice goes to each device, the mapped axis needs one entry per participating device; its size cannot exceed the number of available devices. The common pattern for distributing a batch of shape (B, ...) over N devices is therefore to first reshape it to (N, B/N, ...), so each device receives a block of shape (B/N, ...). If you specify 1, the second axis is mapped instead, and so on.
- None Value: If in_axes for an argument is None, the entire argument is copied and made available to the function on every device. This is typically used for data that needs to be identical across all parallel computations, such as model parameters or shared configuration values.
- PyTree Structure: in_axes should be a tuple (or list, dict) whose structure mirrors the arguments. For example, if your function is def my_func(x, y): ..., you might use in_axes=(0, None). This maps the first argument x over its first axis but replicates the second argument y on all devices.

Let's visualize distributing an array data of shape (8, 100) across 4 devices. Reshaping it to (4, 2, 100) and mapping with in_axes=0 sends one (2, 100) block to each device:
Data distribution with in_axes=0 across 4 devices. Each device receives one (2, 100) block from the first axis of the reshaped input.
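The snippet below is a minimal sketch of that preparation step; the array contents and the assumed count of 4 devices are illustrative, and only an eventual pmap call would actually require 4 devices (the reshape itself runs anywhere):

import jax.numpy as jnp

# Full batch on the host: 8 examples with 100 features each.
data = jnp.arange(800, dtype=jnp.float32).reshape((8, 100))

# Assumed device count for this sketch.
n_devices = 4

# Reshape so the leading axis has one entry per device: (8, 100) -> (4, 2, 100).
sharded = data.reshape((n_devices, -1, 100))
print(sharded.shape)  # (4, 2, 100); with in_axes=0, device i would receive sharded[i]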
If an argument doesn't have the specified mapping axis (e.g., in_axes=0 for a 0-dimensional scalar), if the mapped arguments disagree on the size of that axis, or if the mapped axis is larger than the number of available devices, JAX will raise an error.
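A small helper like the following (a sketch; the function name and error message are ours, not part of JAX) can make that requirement explicit before calling pmap:

import jax

def shard_batch(batch):
    # Hypothetical helper: reshape a host batch so its leading axis matches the device count.
    n = jax.local_device_count()
    if batch.shape[0] % n != 0:
        raise ValueError(f"Batch size {batch.shape[0]} is not divisible by {n} devices")
    # Result has shape (n, batch_size // n, ...), ready for in_axes=0.
    return batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])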
out_axes
Just as in_axes controls how inputs are distributed, out_axes controls how the results returned by the function from each device are combined back into a single output value on the host.
Like in_axes, out_axes is typically an integer, None, or a PyTree structure matching the function's return value(s).

- Integer Value (e.g., 0): If out_axes is 0, the outputs from each device are stacked together along a new axis 0. If each device produced an array of shape (S, ...), the final combined result has shape (N, S, ...), where N is the number of devices. Often, if the inputs were mapped along axis 0 (in_axes=0), you'll want to stack the outputs along axis 0 (out_axes=0) so the leading device dimension lines up with the inputs; a final reshape on the host then merges it back into the full batch dimension.
- None Value: If out_axes is None, JAX assumes the output value is identical across all devices. This is common when the function computes a value that has already been aggregated across devices (e.g., using a collective operation like lax.psum or lax.pmean, which we'll cover next). In this case, JAX simply returns the output from the first device.
- PyTree Structure: If your function returns multiple values (e.g., return loss, accuracy), out_axes should be a tuple (or list, dict) specifying how each returned element should be handled, like out_axes=(0, None); see the sketch after the figure below.

Continuing the previous example, assume the function on each device processes its (2, 100) block and returns a result of shape (2, 50). Using out_axes=0, the per-device results are stacked into an array of shape (4, 2, 50), which can be reshaped to (8, 50) on the host:
Output gathering with out_axes=0. Results from each device are stacked along the first axis.
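Here is a minimal sketch of the PyTree form mentioned above. The function, array sizes, and names are illustrative, and it assumes at least 2 local devices; the first output differs per device (out_axes=0) while the second is the replicated factor itself (out_axes=None):

import jax
import jax.numpy as jnp

def scale_and_report(x, k):
    # Per-device result (different on each device) and the shared factor (identical everywhere).
    return x * k, k

parallel_fn = jax.pmap(scale_and_report, in_axes=(0, None), out_axes=(0, None))

xs = jnp.ones((jax.local_device_count(), 3))  # one (3,) row per device
scaled, factor = parallel_fn(xs, jnp.float32(2.0))
print(scaled.shape)  # (num_devices, 3): per-device results stacked along axis 0
print(factor)        # 2.0: a single replicated value, not stacked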
Let's consider a simple function and how in_axes and out_axes control pmap. Assume we have 2 devices available.
import jax
import jax.numpy as jnp

# Assume 2 devices are available for this example
# You can check available devices with jax.local_device_count()

# Example function: scales input x by a scalar factor 'k'
def scale(x, k):
    return x * k

# Input data: 4 items, feature size 3
data = jnp.arange(12, dtype=jnp.float32).reshape((4, 3))

# Scalar factor
scalar = jnp.float32(10.0)

# pmap sends one slice of the mapped axis to each device, so first reshape the
# batch so its leading axis equals the device count: (4, 3) -> (2, 2, 3)
n_devices = jax.local_device_count()
sharded_data = data.reshape((n_devices, -1, 3))

# pmap the function:
#  - map 'sharded_data' (x) over axis 0, one (2, 3) block per device
#  - replicate 'scalar' (k) on both devices
#  - stack the per-device results along a new axis 0
pmapped_scale = jax.pmap(scale, in_axes=(0, None), out_axes=0)

# Execute the pmapped function
result = pmapped_scale(sharded_data, scalar)

print("Number of devices:", n_devices)  # Example output: 2
print("Original data shape:", data.shape)
print("Scalar value:", scalar)
print("Sharded data shape:", sharded_data.shape)
print("Pmapped result shape:", result.shape)
print("Merged result:\n", result.reshape(4, 3))

# Expected Output (if run on 2 devices):
# Number of devices: 2
# Original data shape: (4, 3)
# Scalar value: 10.0
# Sharded data shape: (2, 2, 3)
# Pmapped result shape: (2, 2, 3)
# Merged result:
#  [[  0.  10.  20.]
#  [ 30.  40.  50.]
#  [ 60.  70.  80.]
#  [ 90. 100. 110.]]
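Before stepping through what happened, a quick sanity check (a sketch meant to run in the same session as the example above) confirms that the merged parallel result matches the ordinary single-device computation:

import numpy as np

# Continues from the example above: merging the per-device results recovers data * scalar.
np.testing.assert_allclose(np.asarray(result).reshape(4, 3), np.asarray(data) * 10.0)
print("pmap result matches the single-device computation")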
In this example:
- data (shape (4, 3)) is first reshaped to sharded_data (shape (2, 2, 3)) so that its leading axis matches the 2 devices. With in_axes=0, Device 0 gets sharded_data[0] (= data[0:2, :], shape (2, 3)) and Device 1 gets sharded_data[1] (= data[2:4, :], shape (2, 3)).
- scalar (shape ()) has in_axes=None. Both devices receive the value 10.0.
- The scale function runs on each device with its block of x and the replicated k. Device 0 computes data[0:2, :] * 10.0 and Device 1 computes data[2:4, :] * 10.0, each producing a result of shape (2, 3).
- out_axes=0 instructs JAX to stack these (2, 3) results along a new leading axis, yielding a (2, 2, 3) result that reshapes back to the full (4, 3) batch.
Understanding in_axes and out_axes is fundamental to controlling data flow in pmap. By correctly specifying how data should be distributed and results gathered, you can effectively leverage multiple devices for parallel computation, particularly for common data parallelism patterns in machine learning. In the next section, we will look at collective operations, which allow communication between devices within a pmapped computation.