When writing functions intended for parallel execution with pmap, especially those involving collective communication operations like lax.psum or lax.pmean, it's important to specify across which group of devices the communication should occur. In a basic example there is only one mapped axis, so the correspondence between the collective and the pmap is easy to see, but relying on that correspondence implicitly becomes ambiguous or error-prone in more complex scenarios, such as nested pmap calls or functions designed to be reusable in various parallel contexts.
To make the intent explicit and robust, JAX allows you to name the axis being mapped by pmap. You provide this name using the axis_name argument.
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Define a function to be mapped
def scaled_sum(x):
    # Perform some computation
    scaled_x = x * 2.0
    # Sum the results across the devices participating in the pmap,
    # explicitly naming the axis for the collective operation
    total_sum = lax.psum(scaled_x, axis_name='devices')
    return total_sum

# Number of devices
n_devices = jax.local_device_count()
print(f"Using {n_devices} devices.")

# Create some input data, sharded across devices
data = jnp.arange(n_devices, dtype=jnp.float32)

# Apply pmap, providing the axis_name 'devices'.
# This binds the string 'devices' to the mapped axis (axis 0 of the input).
parallel_computation = pmap(scaled_sum, axis_name='devices')

# Run the computation
result = parallel_computation(data)
print("Input data:", data)

# The result is the same on every device: sum(input * 2)
# Example: if data=[0., 1.], result=[2., 2.] because 0*2 + 1*2 = 2
# Example: if data=[0., 1., 2., 3.], result=[12., 12., 12., 12.] because 0*2 + 1*2 + 2*2 + 3*2 = 12
print("Result (same on all devices):", result)

# Verify the sum manually
expected_sum = jnp.sum(data * 2.0)
print("Expected sum:", expected_sum)

# The output 'result' is replicated across all devices;
# accessing result[0] gives the computed total sum.
assert jnp.allclose(result[0], expected_sum)
In this example, pmap(scaled_sum, axis_name='devices') applies the scaled_sum function to each element of the input data array distributed across the available devices. The first dimension of data (of size n_devices) is the mapped axis, and we assign the name 'devices' to it.
Inside scaled_sum, the call lax.psum(scaled_x, axis_name='devices') tells the collective operation to sum the scaled_x values along the axis named 'devices'. The name passed to the collective must match a name bound by an enclosing pmap; in this simple case there is only one mapped axis, so the correspondence is easy to see, but stating the name removes any ambiguity about which axis is being reduced over.
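The named axis is not limited to psum. As a minimal illustrative sketch (the describe helper below is an assumption, not part of the example above), other primitives can reference the same 'devices' name: lax.axis_index reports each device's position along the named axis, and lax.pmean averages across it.

import jax
import jax.numpy as jnp
from jax import lax, pmap

def describe(x):
    # Each device can query its own position along the named axis...
    idx = lax.axis_index('devices')
    # ...and other collectives, such as pmean, reuse the same name
    mean_x = lax.pmean(x, axis_name='devices')
    return x - mean_x, idx

n = jax.local_device_count()
centered, indices = pmap(describe, axis_name='devices')(jnp.arange(n, dtype=jnp.float32))
print(indices)   # [0, 1, ..., n-1], one entry per device
print(centered)  # each device's value minus the mean across the named axis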
Naming the axis pays off in several ways:
- Clarity: lax.psum(..., axis_name='batch') clearly indicates the sum is performed over the batch dimension that was parallelized.
- Nested parallelism: with nested pmap calls, each level can have a distinct axis_name, and collective operations can then target a specific level of parallelism by referencing the appropriate name. For example, you might have pmap(..., axis_name='model_replicas') nested inside pmap(..., axis_name='data_shards'). A collective using axis_name='data_shards' would operate across devices holding different data shards but the same model replica, while axis_name='model_replicas' would operate across devices holding different model replicas for the same data shard (a pattern seen in some model parallelism techniques). A minimal sketch of this arrangement appears after this list.
- Reusability: a function whose collectives reference a particular name, such as 'spatial', can be used within any pmap that maps an axis with that name, regardless of other named axes present.
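Here is that sketch of the nested case. It is illustrative only: the grid shape and the per_device function are assumptions, and the product of the two axis sizes must not exceed the number of available devices. Each collective chooses its level purely by name.

import jax
import jax.numpy as jnp
from jax import lax, pmap
from functools import partial

# Illustrative 2D device grid of shape (n_shards, n_replicas);
# (1, local_device_count) always fits on the available devices.
n_replicas = jax.local_device_count()
n_shards = 1

@partial(pmap, axis_name='data_shards')      # outer mapped axis
@partial(pmap, axis_name='model_replicas')   # inner mapped axis
def per_device(x):
    # Reduce across model replicas only (devices that share a data shard)
    replica_sum = lax.psum(x, axis_name='model_replicas')
    # Average across data shards only (devices that share a replica position)
    return lax.pmean(replica_sum, axis_name='data_shards')

data = jnp.arange(n_shards * n_replicas, dtype=jnp.float32).reshape((n_shards, n_replicas))
print(per_device(data))  # shape (n_shards, n_replicas)

Because each collective refers to its target axis by name, neither call depends on which pmap happens to be outermost.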
A standard use case for collectives in pmap is averaging gradients during data-parallel training. Each device computes gradients for its local data shard, and these gradients need to be averaged across all devices before updating the model parameters.
import jax
import jax.numpy as jnp
from jax import lax, pmap, grad

# Dummy loss function (e.g., mean squared error)
def loss_fn(params, local_batch):
    # Replace with an actual model computation and loss
    predictions = params['w'] * local_batch['x'] + params['b']
    error = predictions - local_batch['y']
    return jnp.mean(error**2)

# Function to compute gradients on one device
def compute_gradients(params, local_batch):
    return grad(loss_fn)(params, local_batch)

# Update step including gradient averaging
def parallel_update_step(params, sharded_batch):
    # Compute gradients locally on each device
    local_grads = compute_gradients(params, sharded_batch)
    # Average gradients across devices using the named axis 'data_parallel_axis'.
    # pmean averages directly; psum followed by a division also works.
    avg_grads = lax.pmean(local_grads, axis_name='data_parallel_axis')
    # Simple gradient descent update (could be replaced with Adam, etc.)
    learning_rate = 0.01
    # Note: in a real scenario, use jax.tree_util.tree_map for arbitrary pytree structures
    new_params = {
        'w': params['w'] - learning_rate * avg_grads['w'],
        'b': params['b'] - learning_rate * avg_grads['b']
    }
    return new_params

# Get number of devices
n_devices = jax.local_device_count()

# Initialize dummy parameters (broadcast to all devices via in_axes=None below)
params = {'w': jnp.ones(()), 'b': jnp.zeros(())}

# Create dummy data sharded across devices
# Shape: (n_devices, batch_per_device, ...)
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape((n_devices, 4, 1))
ys = (xs * 2.0) + 0.5  # True w=2.0, b=0.5
sharded_batch = {'x': xs, 'y': ys}

# Define the parallel update function using pmap with an axis name.
# in_axes=(None, 0) broadcasts params to every device and shards the batch along axis 0.
p_update_step = pmap(parallel_update_step, axis_name='data_parallel_axis', in_axes=(None, 0))

# Execute one update step
new_params = p_update_step(params, sharded_batch)

# The parameters are updated based on the average gradient across all devices.
# new_params is replicated across devices; inspect the copy on device 0:
print("Original params:", params)
print("Updated params (on device 0):", jax.tree_util.tree_map(lambda x: x[0], new_params))
In this training step, pmap maps the parallel_update_step function over the first axis of sharded_batch, which we name 'data_parallel_axis', while in_axes=(None, 0) broadcasts the parameters to every device instead of splitting them. Inside the function, lax.pmean(local_grads, axis_name='data_parallel_axis') ensures that the gradient averaging happens specifically across the devices participating in this data-parallel computation.
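The comment about psum in the update step can be made concrete. As an illustrative sketch (the average_by_psum helper is an assumption, not part of the listing above, and it must run inside a pmap that binds 'data_parallel_axis'), summing the constant 1 over the named axis yields the group size, so pmean can be rewritten as a psum followed by a division:

import jax
from jax import lax

def average_by_psum(local_grads):
    # Summing the constant 1 over the named axis counts the participating devices
    group_size = lax.psum(1, axis_name='data_parallel_axis')
    # Collectives accept pytrees, so the whole gradient dict can be summed at once
    summed = lax.psum(local_grads, axis_name='data_parallel_axis')
    return jax.tree_util.tree_map(lambda g: g / group_size, summed)

Substituting this for the lax.pmean line inside parallel_update_step would produce the same averaged gradients; pmean is simply the more direct spelling.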
While the correspondence between a collective and the mapped axis may seem obvious in simple cases, explicitly naming the axis with axis_name is a highly recommended practice. It makes your distributed computations significantly clearer, less prone to subtle errors, and easier to maintain and compose, especially as your parallelism strategies become more sophisticated.