The jax.pmap function operates using a Single Program, Multiple Data (SPMD) model. This means the same Python function code is executed on multiple devices (such as GPUs or TPU cores), but each device typically works on a different portion of the input data. A central question in this model is how to specify which data goes to which device, and how the results from each device are combined. The in_axes and out_axes arguments of pmap answer these questions.
Think of in_axes and out_axes as instructions for JAX on how to handle the "multiple data" part of SPMD. They specify the axis along which each input array is mapped across devices (one slice per device) and the axis along which the per-device outputs are stacked back together.
Distributing Inputs with in_axes

The in_axes argument tells pmap how to distribute the input arguments of your function across the available devices. It specifies, for each argument, which axis should be mapped (one slice per device).
in_axes is typically an integer, None, or a (potentially nested) structure (like a tuple, list, or dictionary) matching the structure of the function's arguments (a PyTree).

- Integer value (commonly 0): If in_axes for a particular argument is 0, the first axis (axis 0) of that NumPy or JAX array is mapped across the devices: slice i along that axis goes to device i, and the function body sees the argument with that axis removed. The size of the mapped axis must not exceed the number of devices, so to distribute a batch of shape (B, ...) over N devices you typically reshape it to (N, B/N, ...) first, giving each device a block of shape (B/N, ...). This is the most common scenario for distributing batch data. If you specify 1, the second axis is mapped, and so on.
- None value: If in_axes for an argument is None, the entire argument is copied and made available to the function on every device. This is typically used for data that needs to be identical across all parallel computations, such as model parameters or shared configuration values.
- PyTree structure: For functions with multiple arguments, in_axes should be a tuple (or list, dict) whose structure mirrors the arguments. For example, if your function is def my_func(x, y): ..., you might use in_axes=(0, None). This maps the first argument x over its first axis but replicates the second argument y on all devices.

Let's visualize distributing an array data of shape (8, 100) across 4 devices. Since pmap sends one slice of the mapped axis to each device, the array is first reshaped to (4, 2, 100); with in_axes=0, each device then receives one (2, 100) block:
Data distribution with in_axes=0 across 4 devices. The leading axis of the reshaped input is split evenly, one (2, 100) slice per device.
If an argument doesn't have the specified mapping axis (e.g., in_axes=0 for a 0-dimensional scalar), or if the size of the mapped axis exceeds the number of available devices, JAX will raise an error.
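The reshape-then-map pattern can be sketched in a few lines. This is a minimal illustration, and the helper name prepare_for_pmap is purely illustrative (not a JAX API); it checks that the batch divides evenly across the local devices and adds the leading device axis that in_axes=0 will map over:

import jax
import jax.numpy as jnp

def prepare_for_pmap(batch):
    # Hypothetical helper: reshape (B, ...) -> (N, B // N, ...) so that
    # axis 0 has exactly one slice per local device, as pmap requires.
    n_devices = jax.local_device_count()
    batch_size = batch.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"Batch size {batch_size} is not divisible by {n_devices} devices")
    return batch.reshape((n_devices, batch_size // n_devices) + batch.shape[1:])

full_batch = jnp.ones((8, 100))        # global batch of 8 examples
device_batch = prepare_for_pmap(full_batch)
print(device_batch.shape)              # e.g. (4, 2, 100) on 4 devices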
Gathering Outputs with out_axes

Just as in_axes controls how inputs are distributed, out_axes controls how the results returned by the function on each device are combined back into a single output value on the host.
Like in_axes, out_axes is typically an integer, None, or a PyTree structure matching the function's return value(s).

- Integer value (commonly 0): If out_axes is 0, the outputs from each device are stacked together along a new leading axis 0. If each device produced an array of shape (S, ...), the combined result has shape (N, S, ...), where N is the number of devices. If the input was mapped along axis 0 (in_axes=0), you will usually also use out_axes=0, which restores a leading device axis that you can reshape back into the full batch dimension.
- None value: If out_axes is None, JAX assumes the output value is identical across all devices. This is common when the function computes a value that has already been aggregated across devices (e.g., using a collective operation like lax.psum or lax.pmean, which we'll cover next). In this case, JAX simply returns the output from a single device.
- PyTree structure: If the function returns multiple values (e.g., return loss, accuracy), out_axes should be a tuple (or list, dict) specifying how each returned element should be handled, like out_axes=(0, None); see the sketch after the figure below.

Continuing the previous example, assume the function on each device processes its (2, 100) slice and returns a result of shape (2, 50). Using out_axes=0, the gathered result on the host has shape (4, 2, 50):
Output gathering with out_axes=0. Results from each device are stacked along a new leading axis, producing a (4, 2, 50) array.
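As a small sketch of mixed out_axes (the function stats and the axis name "devices" are illustrative, and lax.pmean with axis_name is only previewed here; collectives are covered in the next section), the per-example output is stacked with out_axes 0, while the already-aggregated mean is identical on every device and is returned once via out_axes None:

import jax
import jax.numpy as jnp
from jax import lax

def stats(x):
    # Per-device slice x has shape (2,) here.
    per_example = x * 2.0                                    # differs across devices
    mean_val = lax.pmean(jnp.mean(x), axis_name="devices")   # identical on every device
    return per_example, mean_val

n = jax.local_device_count()
data = jnp.arange(n * 2, dtype=jnp.float32).reshape((n, 2))

per_example, mean_val = jax.pmap(
    stats, axis_name="devices", in_axes=0, out_axes=(0, None))(data)
print(per_example.shape)  # (n, 2): per-device results stacked along a new axis 0
print(mean_val.shape)     # (): a single copy of the replicated scalar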
Let's consider a simple function and how in_axes and out_axes control pmap. Assume we have 2 devices available, so the leading (mapped) axis of the input data must have size 2.
import jax
import jax.numpy as jnp
import numpy as np
# Assume 2 devices are available for this example
# You can check available devices with jax.local_device_count()
# Example function: scales input x by a scalar factor 'k'
def scale(x, k):
    return x * k
# Input data: 4 items, feature size 3, reshaped so the leading axis
# has one slice per device: shape (2, 2, 3)
data = jnp.arange(12, dtype=jnp.float32).reshape((2, 2, 3))
# Scalar factor
scalar = jnp.float32(10.0)
# pmap the function
# Map 'data' (x) over its leading axis, one (2, 3) slice per device
# Replicate 'scalar' (k) on both devices
# Stack the per-device results along a new leading axis
pmapped_scale = jax.pmap(scale, in_axes=(0, None), out_axes=0)
# Execute the pmapped function
result = pmapped_scale(data, scalar)
print("Number of devices:", jax.local_device_count()) # Example output: 2
print("Original data shape:", data.shape)
print("Scalar value:", scalar)
print("Pmapped result shape:", result.shape)
print("Pmapped result:\n", result)
# Expected Output (if run on 2 devices):
# Number of devices: 2
# Original data shape: (2, 2, 3)
# Scalar value: 10.0
# Pmapped result shape: (2, 2, 3)
# Pmapped result:
# [[[  0.  10.  20.]
#   [ 30.  40.  50.]]
#
#  [[ 60.  70.  80.]
#   [ 90. 100. 110.]]]
In this example:
- data (shape (2, 2, 3)) has in_axes=0. On 2 devices, Device 0 gets data[0] (shape (2, 3)) and Device 1 gets data[1] (shape (2, 3)).
- scalar (shape ()) has in_axes=None. Both devices receive the value 10.0.
- The scale function runs on each device with its slice of x and the replicated k. Device 0 computes data[0] * 10.0 and Device 1 computes data[1] * 10.0, each producing a result of shape (2, 3).
- out_axes=0 instructs JAX to stack these (2, 3) results along a new leading axis, yielding the final (2, 2, 3) result.

Understanding in_axes and out_axes is fundamental to controlling data flow in pmap. By correctly specifying how data is distributed and results are gathered, you can effectively use multiple devices for parallel computation, particularly for common data parallelism patterns in machine learning; a typical end-to-end version of this pattern is sketched below. In the next section, we will look at collective operations, which allow communication between devices within a pmapped computation.
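Putting the pieces together, here is a minimal end-to-end sketch of the usual data-parallel workflow (the function apply_model and the parameter names are illustrative, not part of any library): reshape the global batch to add a device axis, pmap with in_axes to replicate parameters and map the batch, then merge the device axis back into the batch axis:

import jax
import jax.numpy as jnp

def apply_model(params, x):
    # Illustrative per-device computation: a single linear layer.
    return x @ params["w"] + params["b"]

n_devices = jax.local_device_count()
batch = jnp.ones((8, 100))   # global batch; 8 is assumed divisible by n_devices
params = {"w": jnp.ones((100, 10)), "b": jnp.zeros((10,))}

# Add a leading device axis: (8, 100) -> (n_devices, 8 // n_devices, 100)
device_batch = batch.reshape((n_devices, -1, 100))

# Replicate 'params' (in_axes None), map the batch (in_axes 0), stack outputs.
parallel_apply = jax.pmap(apply_model, in_axes=(None, 0), out_axes=0)
device_out = parallel_apply(params, device_batch)  # (n_devices, 8 // n_devices, 10)

# Merge the device axis back into the batch axis for downstream use.
output = device_out.reshape((-1, 10))              # (8, 10)
print(output.shape)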