jax.pmap distributes a computation by having each device execute the same program on its local slice of the data, following the Single Program, Multiple Data (SPMD) model. Distributed algorithms often require devices to communicate and synchronize. For example, in data-parallel training, each device calculates gradients from its local data batch, but these gradients must be aggregated (usually averaged) across all devices before the model parameters are updated. This is where collective communication operations become essential.
Collective operations in JAX (jax.lax collectives) allow arrays distributed across multiple devices (along the mapped axis defined by pmap) to participate in a combined computation. The result of a collective operation is typically replicated back to all participating devices. These operations must be called inside the function being transformed by pmap.
axis_name
A fundamental aspect of using collectives within pmap is specifying which axis the communication should happen over. Since pmap itself creates a new mapped axis representing the devices, you need to tell the collective operation to use this specific axis. This is done using the axis_name argument, which must match the axis_name provided to the enclosing pmap function. This explicit naming prevents ambiguity, especially when dealing with nested pmap calls or other complex transformations.
psum, pmean, pmax, pmin
The most common collectives perform reduction operations across devices.
jax.lax.psum (Parallel Sum)
jax.lax.psum calculates the element-wise sum of an array across all devices participating in the mapped axis. Each device contributes its local version of the array, and every device receives the identical resulting array containing the total sum.
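As a minimal sketch of this matching rule, the following runs the same function under two pmap calls, one with a matching axis name and one without. The function name `total` and the axis names "devices" and "workers" are illustrative choices, not part of any API. The input is sized from jax.local_device_count(), so the sketch should also run on a single-device machine.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()            # number of devices pmap can map over
x = jnp.arange(n, dtype=jnp.float32)    # one scalar per device

def total(v):
    # The collective reduces over the axis called "devices"
    return jax.lax.psum(v, axis_name="devices")

# Names match: every device receives the sum of all local values.
matched = jax.pmap(total, axis_name="devices")(x)

# Names differ: "devices" is unbound inside this pmap, so tracing fails.
try:
    jax.pmap(total, axis_name="workers")(x)
    mismatch_raised = False
except Exception as err:  # the exact exception type may vary by JAX version
    mismatch_raised = True
    print(type(err).__name__)
```

The error surfaces when the function is traced, before any device work is launched, so a mismatched name is usually caught on the first call.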
Consider a scenario where each of four devices holds a scalar value:
Aggregation using psum across four devices along the named axis 'devices'. Each device starts with a local value, and after psum, each device holds the total sum (2 + 3 + 1 + 4 = 10).
Here's how you might define a function using psum to be used with pmap:
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def sum_across_devices(local_value):
    # 'batch_axis' must match the axis_name in the pmap call
    total_sum = jax.lax.psum(local_value, axis_name='batch_axis')
    # Each device now has the total sum.
    # We could, for example, use it to scale local computations
    return total_sum

# Usage (assuming 4 devices):
# values = jnp.arange(4.)  # [0., 1., 2., 3.] -> one value per device
# pmapped_sum = jax.pmap(sum_across_devices, axis_name='batch_axis')
# result = pmapped_sum(values)
# # result would be Array([6., 6., 6., 6.], dtype=float32)
# # Each device gets 0+1+2+3 = 6
A common application is summing gradients in data-parallel training before the optimizer step.
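To illustrate that psum is element-wise over per-device arrays, and how the total can be used to scale local values, here is a small sketch. The function name `share_of_total` and the input data are made up for illustration; the code adapts to however many local devices are available.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# One row of three positive values per device: shape (n, 3)
local_rows = 1.0 + jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

def share_of_total(row):
    # Element-wise sum over devices; the result keeps the per-device shape (3,)
    total = jax.lax.psum(row, axis_name="batch_axis")
    # Use the global total to scale the local values
    return total, row / total

totals, shares = jax.pmap(share_of_total, axis_name="batch_axis")(local_rows)
# totals has shape (n, 3): every device holds the same column-wise sum.
# shares has shape (n, 3): each column sums to 1 across devices.
```

Note that psum does not sum within a device's array; the three columns stay separate and only the device axis is reduced.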
jax.lax.pmean (Parallel Mean)
jax.lax.pmean works similarly to psum, but instead of returning the sum, it returns the element-wise mean of the array across all devices along the named axis. It's equivalent to performing a psum and then dividing by the number of devices participating in that axis (which JAX tracks automatically).
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def average_across_devices(local_value):
    # 'batch_axis' must match the axis_name in the pmap call
    average_value = jax.lax.pmean(local_value, axis_name='batch_axis')
    return average_value

# Usage (assuming 4 devices):
# values = jnp.arange(4.)  # [0., 1., 2., 3.]
# pmapped_mean = jax.pmap(average_across_devices, axis_name='batch_axis')
# result = pmapped_mean(values)
# # result would be Array([1.5, 1.5, 1.5, 1.5], dtype=float32)
# # Each device gets (0+1+2+3)/4 = 1.5
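The equivalence between pmean and a psum followed by division can be checked with a short sketch. It uses the idiom of calling psum on the constant 1 to count the devices on the named axis; the function name `mean_two_ways` is hypothetical.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)   # one scalar per device

def mean_two_ways(v):
    via_pmean = jax.lax.pmean(v, axis_name="batch_axis")
    # psum of the constant 1 counts the devices on the named axis
    axis_size = jax.lax.psum(1, axis_name="batch_axis")
    via_psum = jax.lax.psum(v, axis_name="batch_axis") / axis_size
    return via_pmean, via_psum

via_pmean, via_psum = jax.pmap(mean_two_ways, axis_name="batch_axis")(x)
# Both outputs have shape (n,) and hold the same mean on every device.
```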
pmean is the standard operation for averaging gradients in data parallelism. Using pmean directly is generally preferred over psum followed by manual division, as it states the intent clearly and avoids hard-coding the device count.
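A minimal sketch of gradient averaging in data-parallel training might look like the following. The linear model, the squared-error loss, the learning rate of 0.1, and the names `loss_fn` and `parallel_step` are all assumptions chosen for illustration. Parameters are broadcast to every device with in_axes=None, while the data is split along its leading axis.

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Mean squared error of a linear model (illustrative choice)
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch_axis", in_axes=(None, 0, 0))
def parallel_step(w, x_shard, y_shard):
    local_grad = jax.grad(loss_fn)(w, x_shard, y_shard)
    # Average the per-device gradients so every device sees the same value
    avg_grad = jax.lax.pmean(local_grad, axis_name="batch_axis")
    new_w = w - 0.1 * avg_grad   # plain SGD update, identical on all devices
    return avg_grad, new_w

n = jax.local_device_count()
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n, 8, 3))   # 8 examples per device
y = jax.random.normal(key_y, (n, 8))
w = jnp.zeros(3)

avg_grad, new_w = parallel_step(w, x, y)  # both have shape (n, 3)

# With equally sized shards, the averaged gradient should match the
# gradient of the loss over the whole batch on a single device.
full_grad = jax.grad(loss_fn)(w, x.reshape(-1, 3), y.reshape(-1))
```

Because every device applies the same averaged gradient, the parameter copies stay in sync without any further communication.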
jax.lax.pmax and jax.lax.pmin (Parallel Max/Min)
These collectives compute the element-wise maximum (pmax) or minimum (pmin) across devices along the named axis. Every device receives the identical resulting array containing the maximum or minimum values found across all participants.
Use cases include synchronizing flags (e.g., did any device encounter an error?), finding the maximum loss observed across batches, or other forms of distributed coordination.
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def get_max_value(local_value):
    # 'data_split' must match the axis_name in the pmap call
    max_val = jax.lax.pmax(local_value, axis_name='data_split')
    return max_val

# Usage (assuming 4 devices):
# values = jnp.array([2., 5., 1., 4.])  # One value per device
# pmapped_max = jax.pmap(get_max_value, axis_name='data_split')
# result = pmapped_max(values)
# # result would be Array([5., 5., 5., 5.], dtype=float32)
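The flag-synchronization use case mentioned above can be sketched as follows. Each device checks its local batch for NaN values, and pmax and pmin turn the local flags into a global "any" and "all". The function name `sync_flags` and the test data are hypothetical, and the flags are cast to int32 so the sketch relies only on integer reductions.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
clean = jnp.ones((n, 4), dtype=jnp.float32)
corrupt = clean.at[n - 1, 0].set(jnp.nan)   # NaN on the last device only

def sync_flags(local_batch):
    has_nan = jnp.any(jnp.isnan(local_batch)).astype(jnp.int32)
    # 1 on every device if any device saw a NaN
    any_bad = jax.lax.pmax(has_nan, axis_name="data_split")
    # 1 on every device only if all devices are clean
    all_ok = jax.lax.pmin(1 - has_nan, axis_name="data_split")
    return any_bad, all_ok

check = jax.pmap(sync_flags, axis_name="data_split")
any_bad_clean, all_ok_clean = check(clean)
any_bad_corrupt, all_ok_corrupt = check(corrupt)
```

Since every device receives the same flag, all of them can take the same branch afterwards, for example skipping an update step together.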
While reduction collectives are very common, JAX provides others for different communication patterns.
jax.lax.all_gather: This operation gathers the input arrays from all devices and, by default, stacks them along a new leading axis (passing tiled=True concatenates them along an existing axis instead). Each device receives the full gathered array. This is useful if every device needs access to the data from all other devices, but be mindful that the output array size scales linearly with the number of devices, potentially consuming significant memory.
jax.lax.ppermute: This is a more general collective allowing devices to exchange data based on a permutation rule specified by the perm argument, a sequence of (source index, destination index) pairs describing which device sends to which. ppermute is fundamental for implementing more complex parallelism schemes, such as ring-allreduce or model parallelism communication patterns, although using it directly requires careful index management.
jax.lax.axis_index: While not strictly a communication primitive, jax.lax.axis_index(axis_name) is often used in conjunction with collectives. It returns the integer index (ID) of the current device within the specified mapped axis. This allows devices to behave differently based on their position within the group, which can be useful for specific communication patterns or workload balancing.
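The three operations above can be sketched together in one small function. The name `exchange`, the axis name "devices", and the ring-shaped permutation are illustrative assumptions; the ring is built in Python from jax.local_device_count() so that it is a static list of (source, destination) pairs.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# One row of two values per device: shape (n, 2)
x = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2)
# Ring permutation: device i sends to device (i + 1) % n
ring = [(i, (i + 1) % n) for i in range(n)]

def exchange(row):
    # Position of this device along the mapped axis
    idx = jax.lax.axis_index("devices")
    # Default: stack along a new leading axis, per-device shape (n, 2)
    stacked = jax.lax.all_gather(row, axis_name="devices")
    # tiled=True: concatenate along the existing axis, per-device shape (n * 2,)
    flat = jax.lax.all_gather(row, axis_name="devices", tiled=True)
    # Each device receives the row held by its left neighbour in the ring
    from_left = jax.lax.ppermute(row, axis_name="devices", perm=ring)
    return idx, stacked, flat, from_left

idx, stacked, flat, from_left = jax.pmap(exchange, axis_name="devices")(x)
# idx has shape (n,), stacked (n, n, 2), flat (n, n * 2), from_left (n, 2)
```

On a single device the ring degenerates to a device sending to itself, so the interesting behaviour only shows up with two or more devices.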
Collective communication primitives are the building blocks for coordinating work within pmap. They enable essential patterns like gradient aggregation in data parallelism, synchronization points, and data exchange for more advanced distributed algorithms. Understanding psum, pmean, and the role of axis_name is fundamental to scaling JAX computations effectively across multiple accelerators. While these operations introduce communication overhead, JAX and XLA work to optimize their execution on underlying hardware like GPUs and TPUs, often utilizing high-speed interconnects like NVLink or inter-chip interconnects (ICI).