pmap with Other Transformations

JAX's power truly shines when its function transformations are combined. While pmap provides the mechanism for distributing computations across multiple devices, it rarely operates in isolation. The computations running on each device often need to be compiled for speed (jit), differentiated for optimization (grad), or vectorized for efficiency (vmap). Thankfully, JAX transformations are designed to be composable, allowing you to layer these capabilities naturally.
This section shows how pmap interacts with jit, grad, and vmap, enabling you to build sophisticated, high-performance distributed programs.
pmap and jit

You might wonder if you need to explicitly combine pmap and jit, perhaps by writing pmap(jit(my_function)). The answer is generally no, because pmap already includes JIT compilation.
When you apply pmap to a function, JAX traces the function (similar to jit) and compiles it using XLA for the target devices (CPU, GPU, or TPU). This compilation is essential for pmap's performance, as it optimizes the code specifically for parallel execution across devices.
import jax
import jax.numpy as jnp
import numpy as np
import time
# Assume we have 2 devices available for this example
num_devices = 2
devices = jax.local_devices()[:num_devices]
print(f"Using {len(devices)} devices: {devices}")
# A simple function that could benefit from JIT
def complex_computation(x):
    y = jnp.sin(x) * jnp.cos(x)
    z = jnp.tanh(y) + jnp.sqrt(jnp.abs(x))
    return z * 2.0
# Apply pmap directly
pmap_complex_computation = jax.pmap(complex_computation)
# Create data sharded across devices
data = np.arange(8.0).reshape(num_devices, -1) # Shape (2, 4)
sharded_data = jax.device_put_sharded(list(data), devices)  # one shard of shape (4,) per device
# Run the pmap'd function (includes JIT compilation)
start_time = time.time()
result = pmap_complex_computation(sharded_data)
result.block_until_ready() # Ensure computation finishes before timing
end_time = time.time()
print(f"pmap execution time: {end_time - start_time:.6f} seconds")
print("Result shape:", result.shape)
# print("Result:\n", result) # Uncomment to see the result
# Explicitly JITting first doesn't typically add value
# jit_then_pmap = jax.pmap(jax.jit(complex_computation))
# start_time = time.time()
# result_jit_then_pmap = jit_then_pmap(sharded_data)
# result_jit_then_pmap.block_until_ready()
# end_time = time.time()
# print(f"jit -> pmap execution time: {end_time - start_time:.6f} seconds")
Executing this code demonstrates that pmap handles the compilation. While jax.pmap(jax.jit(f)) is valid, it usually offers no advantage over jax.pmap(f), as pmap performs its own JIT compilation tailored for multi-device execution. The primary scenario where you might JIT a function separately is if it's a component used within the main function passed to pmap, and you want to control its compilation independently.
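For example, here is a minimal sketch of that scenario (the names shared_kernel and per_device_work are hypothetical): the helper is jitted because it is also called on its own, outside of pmap, while the pmapped wrapper is compiled by pmap as usual.
import jax
import jax.numpy as jnp
import numpy as np

num_devices = 2  # assumed, as in the example above

# Hypothetical helper that is also used outside pmap, so we jit it ourselves
@jax.jit
def shared_kernel(x):
    return jnp.tanh(x) * 2.0

def per_device_work(x):
    # Calling the jitted helper here is fine; pmap still compiles the whole body
    return shared_kernel(x) + 1.0

pmap_work = jax.pmap(per_device_work)

single_result = shared_kernel(jnp.ones(4))           # standalone, single-device call
multi_result = pmap_work(np.ones((num_devices, 4)))  # distributed call across devices
print(multi_result.shape)  # (2, 4)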
pmap and grad: Distributed Gradient Computation

A fundamental use case for pmap in machine learning is data parallelism: training a model on large datasets by distributing batches of data across multiple devices. This requires calculating gradients based on the data shard on each device and then aggregating these gradients. Combining pmap with grad (or value_and_grad) achieves exactly this.
The typical pattern involves:
1. Using jax.grad or jax.value_and_grad to create a function that computes the gradient of the loss with respect to the parameters.
2. Applying pmap to this gradient-calculating function. pmap will execute the gradient calculation on each device using its local data shard.
3. Using a collective operation, typically jax.lax.pmean, inside the pmapped function to average the gradients calculated across all devices. This ensures all devices get the same, globally averaged gradient to update the model parameters consistently.

Let's illustrate with a simplified example:
import jax
import jax.numpy as jnp
import numpy as np
# Assume 2 devices
num_devices = 2
devices = jax.local_devices()[:num_devices]
# Simple model and loss function
def predict(params, x):
    # A simple linear model: y = w*x + b
    w, b = params
    return w * x + b

def loss_fn(params, x_batch, y_batch):
    predictions = predict(params, x_batch)
    error = predictions - y_batch
    return jnp.mean(error**2)  # Mean squared error
# Function to compute value (loss) and gradient
value_and_grad_fn = jax.value_and_grad(loss_fn)
# Function to be pmap'd: computes gradients per device and averages them
def parallel_update_step(params, x_shards, y_shards):
    # Calculate gradients locally on each device's data shard
    loss, grads = value_and_grad_fn(params, x_shards, y_shards)
    # Average gradients across all devices
    # 'axis_name' must match the one provided in pmap
    avg_grads = jax.lax.pmean(grads, axis_name='devices')
    # Optionally average loss too (for logging)
    avg_loss = jax.lax.pmean(loss, axis_name='devices')
    return avg_loss, avg_grads
# Apply pmap, specifying the mapped axis and the collective axis name
# Parameters are replicated, data is sharded along axis 0
pmap_update_step = jax.pmap(
parallel_update_step,
axis_name='devices', # Name for collective operations
in_axes=(None, 0, 0), # Replicate params, map x and y along axis 0
out_axes=(None, None) # Return averaged loss/grads (same on all devices)
)
# Example parameters and data
params = (jnp.array(2.0), jnp.array(1.0)) # (w, b)
# Create data for 2 devices, 4 samples each
x_data = np.arange(8.0).reshape(num_devices, -1) # Shape (2, 4)
y_data = (3 * x_data + 2 + np.random.randn(*x_data.shape) * 0.5) # y = 3x + 2 + noise
# Put data onto devices
sharded_x = jax.device_put_sharded(list(x_data), devices)  # one shard per device
sharded_y = jax.device_put_sharded(list(y_data), devices)
# Execute the parallel gradient computation
avg_loss, avg_grads = pmap_update_step(params, sharded_x, sharded_y)
print(f"Average Loss across devices: {avg_loss:.4f}")
print(f"Averaged Gradients (dw, db): ({avg_grads[0]:.4f}, {avg_grads[1]:.4f})")
# avg_grads can now be used to update the parameters 'params'
# Note: The returned avg_loss and avg_grads are regular JAX arrays,
# not sharded, because we used out_axes=None after pmean.
In this example:
- value_and_grad_fn creates the function to compute loss and gradients.
- parallel_update_step wraps this, adding the jax.lax.pmean collective to average gradients (and loss) across devices participating in the pmap. The axis_name='devices' links the pmean operation to the pmap context.
- jax.pmap is configured with in_axes=(None, 0, 0), meaning:
  - params are not mapped (None), so each device gets a full copy (replication).
  - x_shards and y_shards are mapped along their first axis (0), distributing the data.
- out_axes=(None, None) ensures the averaged results (which are identical across devices after pmean) are returned as regular, unreplicated JAX arrays.

This pmap(value_and_grad(...)) pattern, combined with collectives, is the foundation of distributed data-parallel training in JAX.
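As a follow-up to the comment in the code above, here is a minimal sketch (not part of the original example; learning_rate and apply_sgd are assumed names) of how the averaged gradients could drive a plain SGD update of the replicated parameters:
learning_rate = 0.01  # assumed hyperparameter for this sketch

def apply_sgd(params, grads, lr):
    # Update every leaf of the parameter pytree with its matching averaged gradient
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

new_params = apply_sgd(params, avg_grads, learning_rate)
print("Updated (w, b):", new_params)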
pmap and vmap

Combining pmap and vmap is less common than combining pmap with grad or jit, primarily because pmap itself performs a form of mapping across devices. However, vmap can still be useful inside the function being processed by pmap.
Recall that pmap implements SPMD (Single Program, Multiple Data): the same function runs on each device, but on different data slices. vmap vectorizes operations within a single execution trace.
You might use vmap inside a pmapped function if you need an additional level of vectorization that operates independently on each device's data shard.
Consider a scenario where each device processes a batch of images (pmap), and for each image, you want to apply the same operation to multiple patches extracted from it (vmap).
import jax
import jax.numpy as jnp
import numpy as np
# Assume 2 devices
num_devices = 2
devices = jax.local_devices()[:num_devices]
# Function operating on a single item (e.g., an image patch)
def process_item(item):
    return jnp.tanh(item) * 2.0

# Function using vmap to process multiple items (e.g., patches within an image)
# This function runs *on each device*
def process_batch_of_items(batch):
    # Use vmap for vectorization *within* the device's work
    vectorized_processor = jax.vmap(process_item)
    return vectorized_processor(batch)
# pmap distributes batches across devices
pmap_process_batches = jax.pmap(
process_batch_of_items,
in_axes=0 # Map the first axis of the input data across devices
)
# Example: 2 devices, each processing a batch of 3 items of size 4
# Total data shape: (num_devices, batch_per_device, item_size) = (2, 3, 4)
data = np.arange(2 * 3 * 4.0).reshape(num_devices, 3, 4)
sharded_data = jax.device_put_sharded(list(data), devices)  # one (3, 4) shard per device
# Execute: pmap distributes the (3, 4) batches,
# vmap inside processes the 3 items in parallel on each device.
result = pmap_process_batches(sharded_data)
result.block_until_ready()
print("Input data shape per device:", data.shape[1:]) # (3, 4)
print("Output result shape:", result.shape) # (2, 3, 4)
# Each device gets (3, 4), vmap operates over axis 0 (the 3 items),
# pmap concatenates results along axis 0 (the 2 devices)
Here, pmap distributes the outer dimension (size 2) across devices. On each device, process_batch_of_items receives a slice of shape (3, 4). Inside this function, vmap(process_item) automatically vectorizes the process_item function over the leading axis (size 3) of the data it receives.
While you can technically write pmap(vmap(f)), it often overlaps with what pmap(f, in_axes=...) already achieves. Using vmap inside the function passed to pmap is generally the more intuitive and common way to introduce further vectorization within each device's parallel computation.
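To make that overlap concrete, the short sketch below (reusing process_item, sharded_data, and result from the example above) spells out the direct pmap(vmap(f)) composition and checks that it matches the vmap-inside-pmap version:
# Direct composition: pmap maps over devices, vmap maps over each device's shard
composed_processor = jax.pmap(jax.vmap(process_item))
result_composed = composed_processor(sharded_data)

# Both spellings produce the same values for this data
print(jnp.allclose(result_composed, result))  # True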
The typical order when combining these transformations reflects their roles:
1. Inner transformations (vmap, grad, jit): These define the core computation logic. vmap vectorizes, grad computes derivatives, and jit (often implicitly handled by pmap or applied to inner functions) compiles for single-device efficiency.
2. Outer transformation (pmap): This orchestrates the execution across multiple devices, handling data distribution, launching the (potentially already transformed) inner function on each device, and managing collective communication.

Therefore, patterns like pmap(grad(jit(my_loss))) or pmap(jit(vmap(my_kernel))) (where the jit might be implicit in pmap) are common, with pmap acting as the top-level distributor.
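As a compact illustration of this ordering (reusing loss_fn, params, sharded_x, and sharded_y from the gradient example above), the sketch below composes the transformations explicitly; note that without a pmean collective the resulting gradients differ per device:
# pmap(grad(jit(...))): inner transformations define the per-device work,
# the outer pmap distributes it across devices
grad_step = jax.pmap(
    jax.grad(jax.jit(loss_fn)),   # the explicit jit is optional; pmap compiles anyway
    in_axes=(None, 0, 0),         # replicate params, shard x and y along axis 0
)
per_device_grads = grad_step(params, sharded_x, sharded_y)
# Each gradient now carries a leading device axis, e.g. dw has shape (num_devices,)
print(per_device_grads[0].shape)  # (2,)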
By understanding how to compose pmap with jit, grad, and vmap, you can leverage the full potential of JAX, creating highly efficient programs that scale effectively across modern hardware accelerators. This composability is a defining feature of JAX, enabling complex workflows like distributed model training to be expressed concisely and efficiently.