When you parallelize computation across multiple devices using jax.pmap, each device executes the same function (the "Single Program" in SPMD) but operates on its own slice of the data (the "Multiple Data"). While this independent processing is powerful, you often need mechanisms for these parallel executions to communicate and aggregate results. For instance, in distributed machine learning training, each device might compute gradients based on its local data batch, but you need to combine these gradients across all devices before updating the model parameters. This is where collective operations come in.

Collective operations are specific functions, found mostly in the jax.lax module (JAX's low-level API), designed to work inside pmap. They orchestrate communication and computation across the devices participating in the parallel execution.
psum and pmean

Two of the most frequently used collectives are lax.psum (parallel sum) and lax.pmean (parallel mean).
Imagine you have a value, say local_value, computed independently on each device. You want to calculate the total sum of these local_value results across all devices involved in the pmap. You can achieve this using lax.psum:
import jax
import jax.numpy as jnp
from jax import lax
# Assume we have 4 devices available for this example
# In a real scenario, JAX detects available devices
num_devices = 4
# Example function to be pmapped
def calculate_local_sum(x):
    local_value = jnp.sum(x)  # Each device computes a sum on its slice
    # Sum 'local_value' across all devices participating in the pmap,
    # identified by the axis name 'devices'
    total_sum = lax.psum(local_value, axis_name='devices')
    return total_sum
# Create dummy data, sharded across devices
# Shape: (num_devices, data_per_device)
data = jnp.arange(num_devices * 3).reshape((num_devices, 3))
# data on device 0: [0, 1, 2] -> local_sum = 3
# data on device 1: [3, 4, 5] -> local_sum = 12
# data on device 2: [6, 7, 8] -> local_sum = 21
# data on device 3: [9, 10, 11] -> local_sum = 30
# Apply pmap, naming the mapped axis 'devices'
# The axis_name here must match the one used in lax.psum
pmapped_calculate_sum = jax.pmap(calculate_local_sum, axis_name='devices')
# Execute the pmapped function
result = pmapped_calculate_sum(data)
# The result will contain the *total* sum on each device
# total_sum = 3 + 12 + 21 + 30 = 66
print(result)
# Expected output (on 4 devices): [66 66 66 66]
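The example above hard-codes num_devices for clarity. On real hardware you would typically size the leading axis from the runtime device count instead of assuming four devices. Here is a minimal sketch of that pattern (the variable names are illustrative; it also runs on a single-device CPU setup, where the "global" sum is just the one local value):

import jax
import jax.numpy as jnp
from jax import lax

# Query how many devices JAX can actually see on this host
# (falls back to 1 on a plain CPU setup).
n_dev = jax.local_device_count()

# Shard the batch so its leading axis matches the device count,
# which is what jax.pmap expects by default.
data = jnp.arange(n_dev * 3).reshape((n_dev, 3))

result = jax.pmap(
    lambda x: lax.psum(jnp.sum(x), axis_name='devices'),
    axis_name='devices',
)(data)
print(result)  # One copy of the global sum per device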
Key aspects of lax.psum(value, axis_name):

- value: the quantity computed locally on each device. These local values are summed across every device sharing the same axis_name.
- axis_name: this string identifier links the collective operation to the specific pmap execution it belongs to. You define the name in the jax.pmap call (axis_name='devices' in the example) and use the same name within the collective (lax.psum(..., axis_name='devices')). This ensures the sum happens over the correct group of devices, which is especially important if you ever nest pmap calls.
- Result: lax.psum returns the same total sum to each device, so all parallel executions have access to the aggregated result.

lax.pmean works analogously but computes the mean instead of the sum. It is equivalent to lax.psum(value, axis_name) / N, where N is the number of devices along the mapped axis. This is very common for averaging gradients in data-parallel training.
import jax
import jax.numpy as jnp
from jax import lax
# Assume 4 devices
num_devices = 4
def calculate_local_mean(x):
    local_value = jnp.mean(x)  # Example local computation
    # Calculate the mean of 'local_value' across all devices
    global_mean = lax.pmean(local_value, axis_name='devices')
    return global_mean
data = jnp.arange(num_devices * 3, dtype=jnp.float32).reshape((num_devices, 3))
# Local means: 1.0, 4.0, 7.0, 10.0
pmapped_calculate_mean = jax.pmap(calculate_local_mean, axis_name='devices')
result = pmapped_calculate_mean(data)
# The result will contain the mean of the local means on each device
# global_mean = (1.0 + 4.0 + 7.0 + 10.0) / 4 = 22.0 / 4 = 5.5
print(result)
# Expected output (on 4 devices): [5.5 5.5 5.5 5.5]
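To connect this back to the data-parallel training use case mentioned earlier, the sketch below shows where lax.pmean typically sits in a pmapped gradient step. It uses a made-up toy linear model; loss_fn, train_step, and the shapes are illustrative assumptions, not part of the examples above:

import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(w, x, y):
    # Toy linear model: mean squared error on this device's shard.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

def train_step(w, x, y):
    # Each device differentiates the loss on its local batch...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...then the gradients are averaged across all devices, just as
    # the local means were averaged above.
    grads = lax.pmean(grads, axis_name='devices')
    # Every device applies the same averaged update, which keeps the
    # replicated parameters identical across devices.
    return w - 0.1 * grads

n_dev = jax.local_device_count()
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (n_dev, 8, 3))  # one batch shard per device
y = jnp.ones((n_dev, 8))

# Replicate the parameters: an identical copy along the leading device axis.
w_repl = jnp.stack([jnp.zeros(3)] * n_dev)

p_train_step = jax.pmap(train_step, axis_name='devices')
w_repl = p_train_step(w_repl, x, y)
print(w_repl)  # Each row (device copy) holds the same updated weights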
While psum and pmean are the workhorses, jax.lax provides other collectives:

- lax.pmax(value, axis_name): finds the maximum value across devices.
- lax.pmin(value, axis_name): finds the minimum value across devices.
- lax.all_gather(value, axis_name): gathers value from every device and stacks the results along a new leading axis. Unlike psum or pmean, which return a single aggregated scalar (if the input value was scalar), all_gather gives each device the full collection of values from all other devices (a short sketch below shows pmax and all_gather in action).

Illustration of lax.psum: local values from each device are summed, and the single total sum is made available back to all participating devices.
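As a quick sketch of how these other collectives behave inside pmap, the snippet below reuses the same kind of sharded data as before (the function and variable names are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

def inspect(x):
    local = jnp.sum(x)
    # Largest local sum, returned identically to every device.
    biggest = lax.pmax(local, axis_name='devices')
    # Every device receives all the local sums, stacked along a new
    # leading axis whose length equals the number of devices.
    everyone = lax.all_gather(local, axis_name='devices')
    return biggest, everyone

n_dev = jax.local_device_count()
data = jnp.arange(n_dev * 3).reshape((n_dev, 3))
biggest, everyone = jax.pmap(inspect, axis_name='devices')(data)
print(biggest)   # shape (n_dev,): the maximum local sum, replicated
print(everyone)  # shape (n_dev, n_dev): each device holds every local sum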
Remember these points when working with collectives:

- Collectives only work inside functions transformed by jax.pmap. Calling them outside this context will result in an error.
- axis_name consistency: the axis_name string used in the collective function (lax.psum, lax.pmean, etc.) must match the axis_name specified in the corresponding jax.pmap call. This ensures the operation happens over the intended set of devices.
- Reduction collectives (e.g., psum, pmean) return the same aggregated value to every participating device.

Collective operations are fundamental tools for writing effective multi-device programs with pmap, enabling the coordination and data aggregation essential for algorithms like distributed training. By mastering psum, pmean, and the role of axis_name, you can scale your JAX computations effectively across the available hardware accelerators.