JAX's power truly shines when its function transformations are combined. While pmap provides the mechanism for distributing computations across multiple devices, it rarely operates in isolation. The computations running on each device often need to be compiled for speed (jit), differentiated for optimization (grad), or vectorized for efficiency (vmap). Thankfully, JAX transformations are designed to be composable, allowing you to layer these capabilities naturally.
This section explores how pmap interacts with jit, grad, and vmap, enabling you to build sophisticated, high-performance distributed programs.
pmap and jit
You might wonder whether you need to explicitly combine pmap and jit, perhaps by writing pmap(jit(my_function)). The answer is generally no, because pmap already includes JIT compilation.
When you apply pmap to a function, JAX traces the function (similar to jit) and compiles it using XLA for the target devices (CPU, GPU, or TPU). This compilation is essential for pmap's performance, as it optimizes the code specifically for parallel execution across devices.
import jax
import jax.numpy as jnp
import numpy as np
import time

# Assume we have 2 devices available for this example
num_devices = 2
devices = jax.local_devices()[:num_devices]
print(f"Using {len(devices)} devices: {devices}")

# A simple function that could benefit from JIT
def complex_computation(x):
    y = jnp.sin(x) * jnp.cos(x)
    z = jnp.tanh(y) + jnp.sqrt(jnp.abs(x))
    return z * 2.0

# Apply pmap directly
pmap_complex_computation = jax.pmap(complex_computation)

# Create data sharded across devices: one (4,) shard per device
data = np.arange(8.0).reshape(num_devices, -1)  # Shape (2, 4)
sharded_data = jax.device_put_sharded(list(data), devices)

# Run the pmap'd function (the first call includes JIT compilation)
start_time = time.time()
result = pmap_complex_computation(sharded_data)
result.block_until_ready()  # Ensure computation finishes before timing
end_time = time.time()

print(f"pmap execution time: {end_time - start_time:.6f} seconds")
print("Result shape:", result.shape)
# print("Result:\n", result)  # Uncomment to see the result

# Explicitly JITting first doesn't typically add value
# jit_then_pmap = jax.pmap(jax.jit(complex_computation))
# start_time = time.time()
# result_jit_then_pmap = jit_then_pmap(sharded_data)
# result_jit_then_pmap.block_until_ready()
# end_time = time.time()
# print(f"jit -> pmap execution time: {end_time - start_time:.6f} seconds")
Executing this code demonstrates that pmap handles the compilation. While jax.pmap(jax.jit(f)) is valid, it usually offers no advantage over jax.pmap(f), since pmap performs its own JIT compilation tailored for multi-device execution. The primary scenario where you might JIT a function separately is when it is a component used within the main function passed to pmap and you want to control its compilation independently, as in the sketch below.
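For instance, here is a minimal sketch of that situation (assuming at least two local devices; the normalize and per_device_step names are hypothetical), where a jit-compiled helper is reused both inside a pmapped function and on its own:
import jax
import jax.numpy as jnp
import numpy as np

num_devices = 2  # assumes at least 2 local devices, as in the example above

# A helper that is also useful on its own on a single device,
# so it gets its own jit-compiled version
@jax.jit
def normalize(x):
    return (x - jnp.mean(x)) / (jnp.std(x) + 1e-6)

def per_device_step(x):
    # Calling the jitted helper inside the pmapped function is fine;
    # pmap traces through it and compiles the whole per-device program
    return jnp.sum(normalize(x) ** 2)

data = np.arange(8.0).reshape(num_devices, -1)  # (2, 4): one row per device
result = jax.pmap(per_device_step)(data)        # pmap shards the leading axis
print(result.shape)  # (2,)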
pmap and grad: Distributed Gradient Computation
A fundamental use case for pmap in machine learning is data parallelism: training a model on large datasets by distributing batches of data across multiple devices. This requires calculating gradients from the data shard on each device and then aggregating those gradients. Combining pmap with grad (or value_and_grad) achieves exactly this.
The typical pattern involves:
1. Using jax.grad or jax.value_and_grad to create a function that computes the gradient of the loss with respect to the parameters.
2. Applying pmap to this gradient-calculating function. pmap will execute the gradient calculation on each device using its local data shard.
3. Using a collective operation, such as jax.lax.pmean, inside the pmapped function to average the gradients calculated across all devices. This ensures all devices get the same, globally averaged gradient to update the model parameters consistently.
Let's illustrate with a simplified example:
import jax
import jax.numpy as jnp
import numpy as np

# Assume 2 devices
num_devices = 2
devices = jax.local_devices()[:num_devices]

# Simple model and loss function
def predict(params, x):
    # A simple linear model: y = w*x + b
    w, b = params
    return w * x + b

def loss_fn(params, x_batch, y_batch):
    predictions = predict(params, x_batch)
    error = predictions - y_batch
    return jnp.mean(error**2)  # Mean squared error

# Function to compute value (loss) and gradient
value_and_grad_fn = jax.value_and_grad(loss_fn)

# Function to be pmap'd: computes gradients per device and averages them
def parallel_update_step(params, x_shards, y_shards):
    # Calculate gradients locally on each device's data shard
    loss, grads = value_and_grad_fn(params, x_shards, y_shards)
    # Average gradients across all devices
    # 'axis_name' must match the one provided to pmap
    avg_grads = jax.lax.pmean(grads, axis_name='devices')
    # Optionally average the loss too (for logging)
    avg_loss = jax.lax.pmean(loss, axis_name='devices')
    return avg_loss, avg_grads

# Apply pmap, specifying the mapped axis and the collective axis name
# Parameters are replicated, data is sharded along axis 0
pmap_update_step = jax.pmap(
    parallel_update_step,
    axis_name='devices',   # Name for collective operations
    in_axes=(None, 0, 0),  # Replicate params, map x and y along axis 0
)

# Example parameters and data
params = (jnp.array(2.0), jnp.array(1.0))  # (w, b)

# Create data for 2 devices, 4 samples each
x_data = np.arange(8.0).reshape(num_devices, -1)  # Shape (2, 4)
y_data = 3 * x_data + 2 + np.random.randn(*x_data.shape) * 0.5  # y = 3x + 2 + noise

# Put data onto devices, one shard per device
sharded_x = jax.device_put_sharded(list(x_data), devices)
sharded_y = jax.device_put_sharded(list(y_data), devices)

# Execute the parallel gradient computation
avg_loss, avg_grads = pmap_update_step(params, sharded_x, sharded_y)

# After pmean, every device holds the same averaged values, so the outputs
# carry a leading device axis of identical copies; keep the copy from device 0
avg_loss = avg_loss[0]
avg_grads = jax.tree_util.tree_map(lambda g: g[0], avg_grads)

print(f"Average Loss across devices: {avg_loss:.4f}")
print(f"Averaged Gradients (dw, db): ({avg_grads[0]:.4f}, {avg_grads[1]:.4f})")

# avg_grads can now be used to update the parameters 'params'
In this example:
- value_and_grad_fn creates the function that computes the loss and its gradients.
- parallel_update_step wraps this, adding the jax.lax.pmean collective to average the gradients (and the loss) across the devices participating in the pmap. The axis_name='devices' argument links the pmean operation to the pmap context.
- jax.pmap is configured with in_axes=(None, 0, 0), meaning:
  - params are not mapped (None), so each device gets a full copy (replication).
  - x_shards and y_shards are mapped along their first axis (0), distributing the data.
- With the default out_axes=0, each output carries a leading device axis. Because the pmean-averaged values are identical on every device, the example simply keeps the copy from device 0 to obtain ordinary, unreplicated arrays.
This pmap(value_and_grad(...)) pattern, combined with collectives, is the foundation of distributed data-parallel training in JAX.
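To complete the loop, here is a minimal sketch of applying the averaged gradients, reusing params and avg_grads from the example above and assuming a plain SGD rule with a hypothetical learning_rate (an optimizer library such as optax could be used instead):
import jax

learning_rate = 0.01  # hypothetical value, for illustration only

# params and avg_grads share the same (w, b) tuple structure,
# so tree_map applies the SGD update leaf by leaf
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g,
    params,
    avg_grads,
)
print("Updated params (w, b):", new_params)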
pmap and vmap
Combining pmap and vmap is less common than combining pmap with grad or jit, primarily because pmap itself performs a form of mapping across devices. However, vmap can still be useful inside the function being processed by pmap.
Recall that pmap implements SPMD (Single Program, Multiple Data): the same function runs on each device, but on different data slices. vmap, by contrast, vectorizes operations within a single execution trace.
You might use vmap inside a pmapped function if you need an additional level of vectorization that operates independently on each device's data shard.
Consider a scenario where each device processes a batch of images (pmap), and for each image you want to apply the same operation to multiple patches extracted from it (vmap).
import jax
import jax.numpy as jnp
import numpy as np

# Assume 2 devices
num_devices = 2
devices = jax.local_devices()[:num_devices]

# Function operating on a single item (e.g., an image patch)
def process_item(item):
    return jnp.tanh(item) * 2.0

# Function using vmap to process multiple items (e.g., patches within an image)
# This function runs *on each device*
def process_batch_of_items(batch):
    # Use vmap for vectorization *within* the device's work
    vectorized_processor = jax.vmap(process_item)
    return vectorized_processor(batch)

# pmap distributes batches across devices
pmap_process_batches = jax.pmap(
    process_batch_of_items,
    in_axes=0  # Map the first axis of the input data across devices
)

# Example: 2 devices, each processing a batch of 3 items of size 4
# Total data shape: (num_devices, batch_per_device, item_size) = (2, 3, 4)
data = np.arange(2 * 3 * 4.0).reshape(num_devices, 3, 4)
sharded_data = jax.device_put_sharded(list(data), devices)

# Execute: pmap distributes the (3, 4) batches across devices;
# vmap inside vectorizes over the 3 items on each device.
result = pmap_process_batches(sharded_data)
result.block_until_ready()

print("Input data shape per device:", data.shape[1:])  # (3, 4)
print("Output result shape:", result.shape)            # (2, 3, 4)
# Each device gets a (3, 4) slice, vmap operates over axis 0 (the 3 items),
# and pmap stacks the per-device results along a new leading axis (the 2 devices)
Here, pmap distributes the outer dimension (size 2) across devices. On each device, process_batch_of_items receives a slice of shape (3, 4). Inside this function, vmap(process_item) vectorizes process_item over the leading axis (size 3) of the data it receives.
While you can technically write pmap(vmap(f)), it often overlaps conceptually with what pmap(f, in_axes=...) already achieves. Using vmap inside the function passed to pmap is generally the more intuitive and common way to introduce further vectorization within each device's parallel computation.
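For comparison, here is a minimal sketch of the nested pmap(vmap(f)) form, reusing the process_item function defined above; on input shaped (num_devices, batch_per_device, item_size) it produces the same result as the version that calls vmap inside the pmapped function:
import jax
import jax.numpy as jnp
import numpy as np

num_devices = 2

def process_item(item):
    return jnp.tanh(item) * 2.0

# pmap maps over the device axis (size 2); vmap maps over the
# per-device batch axis (size 3) within each device's program
pmap_of_vmap = jax.pmap(jax.vmap(process_item))

data = np.arange(2 * 3 * 4.0).reshape(num_devices, 3, 4)
result_nested = pmap_of_vmap(data)  # pmap shards the leading NumPy axis
print(result_nested.shape)  # (2, 3, 4)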
The typical order when combining these transformations reflects their roles:
- Inner transformations (vmap, grad, jit): These define the core computation logic. vmap vectorizes, grad computes derivatives, and jit (often handled implicitly by pmap or applied to inner functions) compiles for single-device efficiency.
- Outer transformation (pmap): This orchestrates execution across multiple devices, handling data distribution, launching the (potentially already transformed) inner function on each device, and managing collective communication.
Therefore, patterns like pmap(grad(jit(my_loss))) or pmap(jit(vmap(my_kernel))) (where the jit might be implicit in pmap) are conceptually common, with pmap acting as the top-level distributor.
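As a minimal sketch of this ordering (my_loss and its toy contents are placeholders, and at least two local devices are assumed), the transformations can be composed explicitly with pmap outermost:
import jax
import jax.numpy as jnp
import numpy as np

num_devices = 2

# Placeholder per-device loss: mean squared value of a scaled data shard
def my_loss(w, x):
    return jnp.mean((w * x) ** 2)

# Inner transformations define the per-device computation:
# jit compiles the loss, grad differentiates it with respect to w
grad_loss = jax.grad(jax.jit(my_loss))

# Outer transformation distributes it: w is replicated, x is sharded
per_device_grads = jax.pmap(grad_loss, in_axes=(None, 0))

w = jnp.array(1.5)
x = np.arange(8.0).reshape(num_devices, -1)  # one (4,) shard per device
print(per_device_grads(w, x))  # one gradient of w per device, shape (2,)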
By understanding how to compose pmap with jit, grad, and vmap, you can leverage the full potential of JAX, creating highly efficient programs that scale effectively across modern hardware accelerators. This composability is a defining feature of JAX, enabling complex workflows like distributed model training to be expressed concisely and efficiently.