When using jax.pmap to distribute computations, each device executes the same program on its local slice of the data (the SPMD model). However, distributed algorithms often require devices to communicate and synchronize. For example, in data-parallel training, each device calculates gradients based on its local data batch, but these gradients need to be aggregated (usually averaged) across all devices before updating the model parameters. This is where collective communication operations become essential.
Collective operations in JAX (the jax.lax collectives) allow arrays distributed across multiple devices (along the mapped axis defined by pmap) to participate in a combined computation. The result of a collective operation is typically replicated back to all participating devices. These operations must be called inside the function being transformed by pmap.
axis_name
A fundamental aspect of using collectives within pmap is specifying which axis the communication should happen over. Since pmap itself creates a new mapped axis representing the devices, you need to tell the collective operation to use this specific axis. This is done with the axis_name argument, which must match the axis_name provided to the enclosing pmap call. This explicit naming prevents ambiguity, especially when dealing with nested pmap calls or other complex transformations.
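For example, in a nested pmap the inner and outer axes get distinct names, and each collective targets exactly one of them. A minimal sketch, assuming at least four devices and the hypothetical axis names 'outer' and 'inner':

import jax
import jax.numpy as jnp

def partial_sum(x):
    # Sums only over the 'inner' axis; each 'outer' replica keeps its own result.
    return jax.lax.psum(x, axis_name='inner')

f = jax.pmap(jax.pmap(partial_sum, axis_name='inner'), axis_name='outer')

# x = jnp.arange(4.).reshape(2, 2)  # 2 outer replicas x 2 inner replicas
# f(x)  # -> [[1., 1.], [5., 5.]]: each inner pair is summed independently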
psum, pmean, pmax, pmin
The most common collectives perform reduction operations across devices.
jax.lax.psum (Parallel Sum)
jax.lax.psum calculates the element-wise sum of an array across all devices participating in the mapped axis. Each device contributes its local version of the array, and every device receives the identical resulting array containing the total sum.
Consider a scenario where each of four devices holds a scalar value:
Aggregation using psum across four devices along the named axis 'devices'. Each device starts with a local value, and after psum, each device holds the total sum (2 + 3 + 1 + 4 = 10).
Here's how you might define a function using psum to be used with pmap:
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def sum_across_devices(local_value):
    # 'batch_axis' must match the axis_name in the pmap call
    total_sum = jax.lax.psum(local_value, axis_name='batch_axis')
    # Each device now has the total sum.
    # We could, for example, use it to scale local computations.
    return total_sum

# Conceptual usage (assuming 4 devices):
# values = jnp.arange(4.)  # [0., 1., 2., 3.] -> one value per device
# pmapped_sum = jax.pmap(sum_across_devices, axis_name='batch_axis')
# result = pmapped_sum(values)
# # result would be DeviceArray([6., 6., 6., 6.], dtype=float32)
# # Each device gets 0 + 1 + 2 + 3 = 6
A common application is summing gradients in data-parallel training before the optimizer step.
jax.lax.pmean (Parallel Mean)
jax.lax.pmean works similarly to psum, but instead of returning the sum, it returns the element-wise mean of the array across all devices along the named axis. It's equivalent to performing a psum and then dividing by the number of devices participating in that axis (which JAX tracks automatically).
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def average_across_devices(local_value):
    # 'batch_axis' must match the axis_name in the pmap call
    average_value = jax.lax.pmean(local_value, axis_name='batch_axis')
    return average_value

# Conceptual usage (assuming 4 devices):
# values = jnp.arange(4.)  # [0., 1., 2., 3.]
# pmapped_mean = jax.pmap(average_across_devices, axis_name='batch_axis')
# result = pmapped_mean(values)
# # result would be DeviceArray([1.5, 1.5, 1.5, 1.5], dtype=float32)
# # Each device gets (0+1+2+3)/4 = 1.5
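The equivalence mentioned above can be spelled out manually; a minimal sketch reusing the 'batch_axis' name, where summing an array of ones yields the device count:

def average_manually(local_value):
    total = jax.lax.psum(local_value, axis_name='batch_axis')
    # psum over an array of ones counts the devices participating in the axis
    n_devices = jax.lax.psum(jnp.ones_like(local_value), axis_name='batch_axis')
    # Same result as jax.lax.pmean(local_value, axis_name='batch_axis')
    return total / n_devices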
pmean is the standard operation for averaging gradients in data parallelism. Using pmean directly is generally preferred over psum followed by manual division, as it can sometimes be implemented more efficiently by the backend.
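To illustrate, here is a sketch of a data-parallel update step that averages gradients with pmean. The linear model, squared-error loss, learning rate, and the axis name 'batch' are all assumptions for the example:

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical linear model with squared-error loss
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)          # gradients from the local shard
    grads = jax.lax.pmean(grads, axis_name='batch')  # average across devices
    # Simple SGD update (learning rate 0.1); params stay identical on every device
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Conceptual usage:
# n_dev = jax.device_count()
# params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}
# replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)
# x = jnp.ones((n_dev, 8, 3))  # one shard of 8 examples per device
# y = jnp.ones((n_dev, 8))
# p_train_step = jax.pmap(train_step, axis_name='batch')
# new_params = p_train_step(replicated, x, y)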
jax.lax.pmax and jax.lax.pmin (Parallel Max/Min)
These collectives compute the element-wise maximum (pmax) or minimum (pmin) across devices along the named axis. Every device receives the identical resulting array containing the maximum or minimum values found across all participants.
Use cases include synchronizing flags (e.g., did any device encounter an error?), finding the maximum loss observed across batches, or other forms of distributed coordination.
import jax
import jax.numpy as jnp

# Function intended to be run via pmap
def get_max_value(local_value):
    max_val = jax.lax.pmax(local_value, axis_name='data_split')
    return max_val

# Conceptual usage (assuming 4 devices):
# values = jnp.array([2., 5., 1., 4.])  # One value per device
# pmapped_max = jax.pmap(get_max_value, axis_name='data_split')
# result = pmapped_max(values)
# # result would be DeviceArray([5., 5., 5., 5.], dtype=float32)
While reduction collectives are very common, JAX provides others for different communication patterns.
jax.lax.all_gather: This operation gathers the input arrays from all devices along the mapped axis. By default each per-device input is stacked along a new leading axis (passing tiled=True concatenates along an existing axis instead), and every device receives the full gathered array. This is useful if every device needs access to the data from all other devices, but be mindful that the output array size scales linearly with the number of devices, potentially consuming significant memory.
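A sketch of all_gather with its default stacking behavior, using an assumed axis name 'devices':

import jax
import jax.numpy as jnp

def gather_all(local_vec):
    # By default the per-device inputs are stacked along a new leading axis,
    # so the output shape is (num_devices,) + local_vec.shape on every device.
    return jax.lax.all_gather(local_vec, axis_name='devices')

# Conceptual usage (assuming 4 devices):
# values = jnp.arange(8.).reshape(4, 2)  # a length-2 vector per device
# gathered = jax.pmap(gather_all, axis_name='devices')(values)
# # gathered.shape == (4, 4, 2): every device holds all four vectors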
jax.lax.ppermute: This is a more general collective allowing devices to exchange data according to a permutation given by the perm argument, a sequence of (source_index, destination_index) pairs describing how data moves between devices. ppermute is fundamental for implementing more complex parallelism schemes, such as ring-allreduce or model parallelism communication patterns, although using it directly requires careful index management.
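A sketch of a simple ring shift built on ppermute, assuming four devices and the axis name 'devices':

import jax
import jax.numpy as jnp

def ring_shift(local_value):
    # (source_index, destination_index) pairs forming a ring over 4 devices
    perm = [(i, (i + 1) % 4) for i in range(4)]
    return jax.lax.ppermute(local_value, axis_name='devices', perm=perm)

# Conceptual usage (assuming 4 devices):
# values = jnp.array([10., 20., 30., 40.])
# shifted = jax.pmap(ring_shift, axis_name='devices')(values)
# # shifted would be [40., 10., 20., 30.]: device i now holds device (i-1)'s value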
jax.lax.axis_index: While not strictly a communication primitive, jax.lax.axis_index(axis_name) is often used in conjunction with collectives. It returns the integer index (ID) of the current device within the specified mapped axis. This allows devices to behave differently based on their position within the group, which can be useful for specific communication patterns or workload balancing.
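For instance, axis_index can be combined with psum to broadcast one device's value to all others; a minimal sketch with an assumed axis name 'devices':

import jax
import jax.numpy as jnp

def broadcast_from_first(local_value):
    idx = jax.lax.axis_index('devices')
    # Only device 0 contributes its value; the others contribute zeros
    contribution = jnp.where(idx == 0, local_value, 0.0)
    # After psum, every device holds device 0's original value
    return jax.lax.psum(contribution, axis_name='devices')

# Conceptual usage (assuming 4 devices):
# values = jnp.array([7., 1., 2., 3.])
# result = jax.pmap(broadcast_from_first, axis_name='devices')(values)
# # result would be [7., 7., 7., 7.]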
Collective communication primitives are the building blocks for coordinating work within pmap. They enable essential patterns like gradient aggregation in data parallelism, synchronization points, and data exchange for more advanced distributed algorithms. Understanding psum, pmean, and the role of axis_name is fundamental to scaling JAX computations effectively across multiple accelerators. While these operations introduce communication overhead, JAX and XLA work to optimize their execution on hardware like GPUs and TPUs, often utilizing high-speed interconnects such as NVLink or inter-chip interconnects (ICI).