While pmap is powerful for distributing computations across a single dimension of devices, complex parallelism strategies often require mapping computations over multi-dimensional grids of accelerators. This is where nesting pmap calls becomes useful, allowing you to build sophisticated data and model parallelism configurations.

Imagine you have a grid of devices, perhaps 4 GPUs arranged logically as a 2x2 grid. You might want to apply data parallelism across the rows and model parallelism across the columns. This requires mapping your computation along two different axes simultaneously. JAX allows you to achieve this by nesting pmap calls: a function decorated with pmap can itself call another function decorated with pmap.
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

# Assume 4 devices are available
num_devices = jax.local_device_count()
if num_devices < 4:
    print(f"Warning: Need at least 4 devices for this example, found {num_devices}")
    # Fallback so the mesh below can still be printed; the nested pmap itself
    # still requires 4 distinct devices to run.
    devices = jax.local_devices()[:1] * 4
else:
    devices = jax.local_devices()[:4]

# View the devices as a 2x2 logical grid (rows = 'data' axis, columns = 'model' axis)
device_mesh = np.array(devices).reshape(2, 2)
print("Logical Device Mesh:")
print(device_mesh)
# Inner function: operates on the data belonging to one 'model' shard.
# We map this function across the 'model' axis (columns of the mesh).
@partial(jax.pmap, axis_name='model')
def inner_op(x_model_shard, params_model_shard):
    # Example: simple computation within a model shard, followed by a sum
    # across the 'model' axis (the devices in this inner pmap call)
    sum_across_models = jax.lax.psum(x_model_shard * params_model_shard, axis_name='model')
    return sum_across_models * 2  # Some arbitrary follow-up operation

# Outer function: operates on the data belonging to one 'data' batch shard.
# We map this function across the 'data' axis (rows of the mesh).
@partial(jax.pmap, axis_name='data')
def outer_op(x_data_shard, params_data_shard):
    # x_data_shard and params_data_shard arrive here with a leading dimension
    # of size num_model_shards.
    # Calling the inner pmap implicitly maps over that leading 'model' dimension.
    result = inner_op(x_data_shard, params_data_shard)
    # Example: sum the results across the 'data' axis (the two data positions
    # of the outer pmap)
    final_result = jax.lax.psum(result, axis_name='data')
    return final_result
# Prepare input data and parameters, sharded across both axes.
# Shape: (num_data_shards, num_model_shards, ...)
data = jnp.arange(16.).reshape(2, 2, 2, 2)   # (data_axis=2, model_axis=2, feature1=2, feature2=2)
params = jnp.ones(16).reshape(2, 2, 2, 2) * 0.5

# Execute the nested pmap. JAX places the shards onto devices according to the
# nested structure: the leading axis feeds the outer 'data' pmap and the second
# axis feeds the inner 'model' pmap.
output = outer_op(data, params)

print("\nInput Data Shape:", data.shape)
print("Output Shape:", output.shape)  # (2, 2, 2, 2); entries along the mapped axes
                                      # are identical because both collectives are psums
print("Output (one replica):")
print(output[0])
# Manual verification with a smaller input: one scalar per device.
# data_simple = jnp.arange(4.).reshape(2, 2), params_simple = 0.5 everywhere.
#
# With four devices 0..3, the shards are placed as:
#   Device 0: data=0.0 (data index 0, model index 0)
#   Device 1: data=1.0 (data index 0, model index 1)
#   Device 2: data=2.0 (data index 1, model index 0)
#   Device 3: data=3.0 (data index 1, model index 1)
#
# Inner pmap for data index 0 (devices 0 and 1):
#   device 0 computes 0.0 * 0.5 = 0.0, device 1 computes 1.0 * 0.5 = 0.5
#   psum over 'model' = 0.5, then * 2 -> 1.0 on both devices 0 and 1
# Inner pmap for data index 1 (devices 2 and 3):
#   device 2 computes 2.0 * 0.5 = 1.0, device 3 computes 3.0 * 0.5 = 1.5
#   psum over 'model' = 2.5, then * 2 -> 5.0 on both devices 2 and 3
#
# Back in outer_op, the psum over 'data' sums across the two data positions,
# i.e., device 0 with device 2 and device 1 with device 3:
#   1.0 + 5.0 = 6.0 on every device
# The outer pmap stacks the (identical) per-data results, so the final output
# has shape (2, 2) with every entry equal to 6.0.
# (The same reasoning applied to the (2, 2, 2, 2) example above gives
#  [[24., 28.], [32., 36.]] replicated along both mapped axes.)
data_simple = jnp.arange(4.).reshape(2, 2)    # (data_axis=2, model_axis=2)
params_simple = jnp.ones(4).reshape(2, 2) * 0.5

# Rerun with the simpler data
output_simple = outer_op(data_simple, params_simple)

print("\n--- Simple Example ---")
print("Input Data Shape:", data_simple.shape)
print("Output Shape:", output_simple.shape)   # (2, 2)
print("Output (one replica):")
print(output_simple[0])                       # Expect [6. 6.]
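Because pmap and vmap share the same named-axis semantics for collectives, one convenient way to sanity-check this logic on a single device is to swap pmap for vmap. The snippet below is a minimal sketch of that check, reusing the imports and arrays from the example above; the lambdas simply mirror inner_op and outer_op.

# Sanity check on one device: vmap supports the same axis_name/psum semantics.
inner_vmapped = jax.vmap(
    lambda x, p: jax.lax.psum(x * p, axis_name='model') * 2, axis_name='model')
outer_vmapped = jax.vmap(
    lambda x, p: jax.lax.psum(inner_vmapped(x, p), axis_name='data'), axis_name='data')

print(outer_vmapped(data_simple, params_simple))  # same values as output_simple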
Key points about nested pmap:

- Axis names are scoped. The axis_name given to an inner pmap is local to that specific call and the devices participating in it for that outer iteration. An outer axis_name refers to collectives across all devices involved in the outer pmap.
- Input arrays must be shaped to match the nested pmap axes. In the example above, data and params had shape (2, 2, ...), mapping to a 2x2 logical device grid. The first dimension corresponds to the data axis of the outer pmap, and the second dimension corresponds to the model axis of the inner pmap.
- Collectives such as psum must specify the axis_name they operate over. This tells JAX which set of devices should participate in the communication (e.g., sum across the 'model' axis within an inner_op call, or sum across the 'data' axis within the outer_op call).
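To see the axis scoping concretely, the short sketch below (reusing the imports above and assuming 4 devices; the names grid_coords and inner_coords are just illustrative) reports each device's coordinate along both named axes using jax.lax.axis_index:

# Each device reports its (data, model) position on the 2x2 logical grid.
@partial(jax.pmap, axis_name='model')
def inner_coords(x):
    # Ignore x; just return this device's index along the inner 'model' axis.
    return jax.lax.axis_index('model')

@partial(jax.pmap, axis_name='data')
def grid_coords(x):
    model_idx = inner_coords(x)              # shape (2,), values [0, 1]
    data_idx = jax.lax.axis_index('data')    # 0 or 1, depending on the outer row
    return jnp.stack([jnp.full_like(model_idx, data_idx), model_idx], axis=-1)

print(grid_coords(jnp.zeros((2, 2))))        # entry [d, m] is the pair (d, m)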
While nesting pmap
allows mapping computations onto multi-dimensional device grids, pmap
itself follows the SPMD principle: the same program runs everywhere, but operates on different slices of data. The partitioning is implicitly defined by how the input arrays are sharded along the mapped axes.
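For instance, here is a minimal sketch (shapes are illustrative) of preparing a global batch for the 2x2 grid used above: split the batch along the 'data' axis and replicate it along the 'model' axis before calling the nested pmap.

# Reshape a global batch of 8 examples (4 features each) into the
# (data_shards, model_shards, per_shard_batch, features) layout.
global_batch = jnp.arange(32.).reshape(8, 4)
per_data_shard = global_batch.reshape(2, 1, 4, 4)               # split batch into 2 data shards
sharded_batch = jnp.broadcast_to(per_data_shard, (2, 2, 4, 4))  # replicate across 'model'
print(sharded_batch.shape)  # (2, 2, 4, 4)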
Nested pmap enables common advanced partitioning patterns:

- Combined data and model parallelism: use an outer pmap for data parallelism (splitting the batch) and an inner pmap for model parallelism (splitting model layers or parameters across devices). Collectives are then used within the appropriate axis scope (e.g., psum gradients across the data axis, all_gather activations across the model axis); a sketch of this pattern follows the diagram description below.
- Other grid-based schemes that combine nested pmap and specific collective communication patterns within each dimension.

Here's a conceptual diagram illustrating a 2x2 device mesh for combined data/model parallelism:
A 2x2 device grid. The outer pmap maps across the 'data' axis (conceptually rows; dashed red lines indicate 'data' axis collectives). For each data shard, the inner pmap maps across the 'model' axis (conceptually columns; solid blue lines indicate 'model' axis collectives).
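As a sketch of the combined data/model parallelism pattern (reusing the imports above and assuming 4 devices; the names train_step and loss_fn and the toy shapes are illustrative, not a fixed recipe), the following splits a weight matrix's columns across the 'model' axis, gathers activations with all_gather, and averages gradients across the 'data' axis:

def loss_fn(w_cols, x):
    # x: this device's batch slice; w_cols: this device's slice of the weight columns.
    h = x @ w_cols                                              # local partial activations
    h_full = jax.lax.all_gather(h, axis_name='model', axis=1, tiled=True)
    return jnp.mean(h_full ** 2)                                # toy loss

@partial(jax.pmap, axis_name='model')
def model_parallel_grads(w_cols, x):
    return jax.value_and_grad(loss_fn)(w_cols, x)

@partial(jax.pmap, axis_name='data')
def train_step(w_cols, x):
    loss, grads = model_parallel_grads(w_cols, x)
    # Average across the 'data' axis, as in ordinary data parallelism.
    return jax.lax.pmean(loss, 'data'), jax.lax.pmean(grads, 'data')

# Toy sharded inputs with leading axes (data_shards=2, model_shards=2).
# In a real setup, x would be replicated along 'model' and w_cols along 'data'.
x_sharded = jax.random.normal(jax.random.PRNGKey(0), (2, 2, 4, 4))  # batch 4, 4 features per shard
w_sharded = jax.random.normal(jax.random.PRNGKey(1), (2, 2, 4, 3))  # 3 weight columns per shard
loss, grads = train_step(w_sharded, x_sharded)
print(loss.shape, grads.shape)  # (2, 2) and (2, 2, 4, 3)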
Manual Device Assignment Considerations:

pmap typically lets JAX manage the device assignment, based on jax.devices() or the devices associated with the input arrays. For highly specific hardware topologies or performance tuning, you might sometimes need more explicit control over which physical device executes which part of the computation. While pmap itself doesn't offer fine-grained manual computation placement within its execution, you can influence it by:

- Pre-placing input shards on specific devices before calling pmap using jax.device_put(array_shard, device). pmap will generally respect this placement if possible.
- Passing pmap(devices=...): you can explicitly pass a list or array of devices to pmap to constrain which devices it runs on, although the assignment within that set is usually automatic (see the short sketch after this section).

For partitioning schemes that deviate significantly from the standard SPMD model implicit in pmap (e.g., complex pipeline parallelism where different devices run fundamentally different stages), other JAX features or libraries might be more appropriate, such as careful use of jax.jit(device=...), experimental libraries focused on explicit partitioning like jax.experimental.shard_map, or frameworks built on jax.Array. However, nested pmap combined with axis naming provides a powerful mechanism for many advanced, grid-based partitioning strategies directly within the pmap framework.
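Here is a minimal sketch of the devices= option mentioned above (reusing the imports from the example and assuming at least 2 local devices; scaled_sum is just an illustrative name):

# Constrain a pmap to an explicit, ordered list of devices.
devs = jax.local_devices()[:2]

@partial(jax.pmap, axis_name='batch', devices=devs)
def scaled_sum(x):
    # Sum the shards across the 'batch' axis, then scale.
    return jax.lax.psum(x, axis_name='batch') * 0.5

x = jnp.arange(6.).reshape(2, 3)   # leading dim must equal len(devs)
print(scaled_sum(x))               # both rows: [1.5 2.5 3.5]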
Mastering nested pmap requires careful attention to data sharding, axis naming, and the scope of collective operations. While it adds complexity compared to a single pmap, it unlocks the ability to scale computations across multi-dimensional accelerator arrays, enabling efficient training of larger models and processing of larger datasets than possible with simpler parallelism schemes.