While pmap is powerful for distributing computations across a single dimension of devices, complex parallelism strategies often require mapping computations over multi-dimensional grids of accelerators. Nesting pmap calls provides a solution for this requirement, allowing you to build sophisticated data and model parallelism configurations.
Imagine you have a grid of devices, perhaps 4 GPUs arranged logically as a 2x2 grid. You might want to apply data parallelism across the rows and model parallelism across the columns. This requires mapping your computation along two different axes simultaneously. JAX allows you to achieve this by nesting pmap calls.
A function decorated with pmap can itself call another function decorated with pmap.
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial

# This example expects at least 4 devices arranged as a logical 2x2 grid.
num_devices = jax.local_device_count()
if num_devices < 4:
    # pmap cannot reuse a device, so the nested example needs 4 distinct devices.
    print(f"Warning: Need at least 4 devices for this example, found {num_devices}")
else:
    devices = jax.local_devices()[:4]
    # Logical 2x2 arrangement: rows form the 'data' axis, columns the 'model' axis.
    device_mesh = np.array(devices).reshape(2, 2)
    print("Logical Device Mesh:")
    print(device_mesh)

# Inner function: operates on the data belonging to one 'model' shard.
# It is mapped across the 'model' axis (the columns of the device grid).
@partial(jax.pmap, axis_name='model')
def inner_op(x_model_shard, params_model_shard):
    # Sum the elementwise product across the 'model' axis, i.e. across all
    # devices participating in this inner pmap call.
    sum_across_models = jax.lax.psum(x_model_shard * params_model_shard, axis_name='model')
    return sum_across_models * 2  # Some arbitrary follow-up operation

# Outer function: operates on the data belonging to one 'data' batch shard.
# It is mapped across the 'data' axis (the rows of the device grid).
@partial(jax.pmap, axis_name='data')
def outer_op(x_data_shard, params_data_shard):
    # x_data_shard and params_data_shard arrive here with a leading dimension
    # of size num_model_shards.
    # Calling the inner pmap maps over that leading dimension (the 'model' axis).
    result = inner_op(x_data_shard, params_data_shard)
    # Sum the results across the 'data' axis, i.e. across the corresponding
    # positions in the other outer pmap instances.
    final_result = jax.lax.psum(result, axis_name='data')
    return final_result

# Input data and parameters, sharded across both axes.
# Shape: (num_data_shards, num_model_shards, ...)
data = jnp.arange(16.).reshape(2, 2, 2, 2)  # (data_axis=2, model_axis=2, feature1=2, feature2=2)
params = jnp.ones((2, 2, 2, 2)) * 0.5

# Execute the nested pmap. JAX places the shards onto the corresponding
# devices based on the nested structure.
output = outer_op(data, params)

print("\nInput Data Shape:", data.shape)
print("Output Shape:", output.shape)  # After the psums, values are identical along the mapped axes
print("Output (one replica):")
print(output[0])
# Manual verification is easier with one scalar per device, so trace the
# simpler input used below:
#
#   data_simple   = [[0., 1.], [2., 3.]]      # (data_axis=2, model_axis=2)
#   params_simple = [[0.5, 0.5], [0.5, 0.5]]
#
# With the devices laid out as [[0, 1], [2, 3]]:
#   Device 0: x=0, p=0.5  (data index 0, model index 0)
#   Device 1: x=1, p=0.5  (data index 0, model index 1)
#   Device 2: x=2, p=0.5  (data index 1, model index 0)
#   Device 3: x=3, p=0.5  (data index 1, model index 1)
#
# inner_op for data index 0 (devices 0 and 1):
#   device 0 computes 0 * 0.5 = 0.0, device 1 computes 1 * 0.5 = 0.5
#   psum over 'model' = 0.5, then * 2 = 1.0 on both devices
# inner_op for data index 1 (devices 2 and 3):
#   device 2 computes 2 * 0.5 = 1.0, device 3 computes 3 * 0.5 = 1.5
#   psum over 'model' = 2.5, then * 2 = 5.0 on both devices
#
# Back in outer_op, psum over 'data' sums corresponding values across the
# data axis only (the device pairs {0, 2} and {1, 3}), not across all four
# devices: 1.0 + 5.0 = 6.0. Every device therefore ends up holding 6.0.
data_simple = jnp.arange(4.).reshape(2, 2)  # (data_axis=2, model_axis=2)
params_simple = jnp.ones((2, 2)) * 0.5

# Rerun with the simpler data
output_simple = outer_op(data_simple, params_simple)
print("\n--- Simple Example ---")
print("Input Data Shape:", data_simple.shape)
print("Output Shape:", output_simple.shape)
print("Output (one replica):")
print(output_simple[0])  # Expect [6., 6.]
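As a quick sanity check that does not require multiple devices (an addition to the example above, not part of it), nested vmap with the same axis names reproduces these numbers, since vmap also supports collectives over named axes:

def core(x, p):
    inner = jax.lax.psum(x * p, axis_name='model') * 2  # inner_op's computation
    return jax.lax.psum(inner, axis_name='data')        # outer_op's reduction

check = jax.vmap(jax.vmap(core, axis_name='model'), axis_name='data')(
    jnp.arange(4.).reshape(2, 2), jnp.ones((2, 2)) * 0.5)
print(check)  # [[6. 6.] [6. 6.]]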
Important points about nested pmap:
- Axis name scoping: axis names defined by an inner pmap are local to that specific call and to the devices participating in it for that outer iteration. A collective over the outer axis_name combines corresponding values across the outer pmap instances, i.e. along the outer mapped axis.
- Input sharding must match the nested pmap axes. In the example above, data and params had shape (2, 2, ...), mapping onto a 2x2 logical device grid: the first dimension corresponds to the 'data' axis of the outer pmap, and the second dimension corresponds to the 'model' axis of the inner pmap.
- Collectives such as psum must specify the axis_name they operate over. This tells JAX which set of devices participates in the communication (e.g., sum across the 'model' axis within an inner_op call, or across the 'data' axis within the outer_op call). The short sketch below makes this scoping concrete.
- SPMD still applies: while nesting pmap maps a computation onto a multi-dimensional device grid, pmap itself follows the SPMD principle: the same program runs everywhere, but operates on different slices of data. The partitioning is implicitly defined by how the input arrays are sharded along the mapped axes.
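To see how the axis_name argument controls the scope of a collective, here is a minimal sketch, assuming the same 4-device 2x2 layout, using the stacked-decorator form of nested pmap from the JAX documentation. Passing a single axis name reduces over that axis only; passing a tuple of names reduces over both axes at once:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name='data')    # outer axis: rows of the device grid
@partial(jax.pmap, axis_name='model')   # inner axis: columns of the device grid
def scoped_sums(x):
    col_sum = jax.lax.psum(x, 'model')             # sums the 2 'model' peers
    row_sum = jax.lax.psum(x, 'data')              # sums the 2 'data' peers
    grid_sum = jax.lax.psum(x, ('data', 'model'))  # sums all 4 devices
    return col_sum, row_sum, grid_sum

x = jnp.arange(4.).reshape(2, 2)  # one scalar per device
col, row, grid = scoped_sums(x)
print(col)   # [[1. 1.] [5. 5.]]
print(row)   # [[2. 4.] [2. 4.]]
print(grid)  # [[6. 6.] [6. 6.]]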
Nested pmap enables common advanced partitioning patterns:
- Combined data and model parallelism: an outer pmap handles data parallelism (splitting the batch) while an inner pmap handles model parallelism (splitting model layers or parameters across devices). Collectives are then used within the appropriate axis scope (e.g., psum gradients across the 'data' axis, all_gather activations across the 'model' axis); a sketch of this combination follows the diagram below.
- Grid-based partitioning: mapping a computation over a two-dimensional device grid with nested pmap and specific collective communication patterns within each dimension.
Here's a diagram illustrating a 2x2 device mesh for combined data/model parallelism:
Diagram: a 2x2 device grid. The outer pmap maps across the 'data' axis (rows; dashed red lines indicate 'data' axis collectives). For each data shard, the inner pmap maps across the 'model' axis (columns; solid blue lines indicate 'model' axis collectives).
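The following is a minimal sketch of the combined pattern on the 2x2 grid above. The function name sharded_step and the toy shapes are illustrative assumptions, not a complete training step: each 'model' column holds a slice of a weight vector that is gathered with all_gather, and the per-shard results are averaged across the 'data' axis with pmean, the same way data-parallel gradients would be:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name='data')    # rows: data parallelism
@partial(jax.pmap, axis_name='model')   # columns: model parallelism
def sharded_step(x_shard, w_shard):
    # Each 'model' device holds a slice of the weights; gather the full
    # vector across the 'model' axis before using it.
    w_full = jax.lax.all_gather(w_shard, axis_name='model').reshape(-1)
    # A toy loss-like scalar computed from this data shard.
    value = jnp.sum(x_shard * w_full)
    # Average across the 'data' axis, as one would average gradients.
    return jax.lax.pmean(value, axis_name='data')

# 2 data shards x 2 model shards, 4 features, weight slices of length 2.
x = jnp.arange(16.).reshape(2, 2, 4)   # (data=2, model=2, features=4)
w = jnp.ones((2, 2, 2)) * 0.1          # (data=2, model=2, weight_slice=2)
print(sharded_step(x, w))              # shape (2, 2); identical along the 'data' axis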
Manual Device Assignment Considerations:
By default, pmap lets JAX manage device assignment based on jax.devices() or the devices associated with the input arrays. For highly specific hardware topologies or performance tuning, you might sometimes need more explicit control over which physical device executes which part of the computation.
While pmap itself doesn't offer fine-grained manual computation placement within its execution, you can influence it by:
- Pre-placing inputs: you can place input shards onto specific devices before calling pmap using jax.device_put(array_shard, device). pmap will generally respect this placement if possible.
- pmap(devices=...): you can explicitly pass a list or array of devices to pmap to constrain which devices it runs on, although the assignment within that set is usually automatic. Both options are sketched below.
For partitioning schemes that deviate significantly from the standard SPMD model implicit in pmap (e.g., complex pipeline parallelism where different devices run fundamentally different stages), other JAX features or libraries may be more appropriate, such as careful use of jax.jit(device=...) or experimental APIs focused on explicit partitioning like jax.experimental.shard_map and tooling built on jax.Array. However, nested pmap combined with axis naming provides a powerful mechanism for many advanced, grid-based partitioning strategies directly within the pmap framework.
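Here is a minimal sketch of the two placement controls mentioned above. The device ordering is arbitrary and only demonstrates the API; jax.device_put_sharded is a related helper (not discussed above) used here to assemble the pre-placed shards into one array that pmap can consume directly:

import jax
import jax.numpy as jnp

devs = jax.local_devices()

# Constrain pmap to an explicit device list (and ordering).
double = jax.pmap(lambda x: x * 2, axis_name='i', devices=devs)

# Pre-place one shard per device, then assemble them into a single array
# whose leading axis is laid out across those devices.
shards = [jax.device_put(jnp.full((3,), float(k)), d) for k, d in enumerate(devs)]
x = jax.device_put_sharded(shards, devs)

print(double(x))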
Mastering nested pmap requires careful attention to data sharding, axis naming, and the scope of collective operations. While it adds complexity compared to a single pmap, it unlocks the ability to scale computations across multi-dimensional accelerator arrays, enabling efficient training of larger models and processing of larger datasets than possible with simpler parallelism schemes.