Debugging code that runs in parallel across multiple devices introduces complexities beyond those found in single-device execution or standard Python. When you wrap a function with `jax.pmap`, you're moving from a single execution stream to multiple streams operating simultaneously under the Single Program, Multiple Data (SPMD) model. Errors might occur not just within the logic of your function, but also in how data is distributed, processed across devices, and gathered. Furthermore, since `pmap` implicitly compiles your function using XLA (similar to `jit`), you also need to consider the challenges associated with debugging compiled code.

Here are common issues encountered with `pmap` and strategies to address them:
**Shape Mismatches:** This is perhaps the most frequent source of errors.

- `pmap` splits input arrays along the axis specified in `in_axes`. The size of this axis must equal the number of devices JAX is using for the `pmap` execution (usually `jax.local_device_count()`, unless you are using a multi-host setup). If you have 4 GPUs and map over an array with `in_axes=0`, the first dimension of that array must be 4. A mismatch will cause an error during the distribution phase.
- Outputs are gathered according to `out_axes`. If the function on each device produces an output shape that's inconsistent with how `pmap` expects to gather it, errors can occur. Ensure the shape produced per device makes sense when stacked.
- Arguments marked `None` in `in_axes` are replicated, not split. Ensure you correctly distinguish between data that should be divided among devices (sharded) and data that should be identical on all devices (replicated). Providing data intended for sharding with `in_axes=None` (or vice versa) leads to incorrect shapes inside the per-device computation.

**Collective Operation Issues:** Functions like `lax.psum`, `lax.pmean`, and `lax.all_gather` coordinate operations across all devices involved in the `pmap`.
- `axis_name`: The `axis_name` used in the collective function (e.g., `lax.psum(x, axis_name='my_devices')`) must exactly match the name provided in the `pmap` call (e.g., `pmap(..., axis_name='my_devices')`). Typographical errors are common.
- Conditional logic inside the `pmap`ped function that causes some devices to skip the collective call will stall or break the computation. Remember, under SPMD, all devices execute the same code; conditional logic should typically depend only on replicated values or the device ID (`lax.axis_index`) if divergence is intended and carefully managed.

**Tracing and Compilation Errors:** Since `pmap` compiles the function, you might encounter errors similar to those with `jit`.
- Data-dependent Python conditionals (a standard Python `if` on traced values) are problematic. Refer back to the debugging techniques for `jit`, such as using `jax.lax.cond` or ensuring conditionals operate on static arguments.
- Side effects inside the `pmap`ped function (like modifying external variables or printing with standard Python `print`) behave unpredictably or cause errors.

**Debugging Strategies for `pmap`**

Debugging `pmap` often involves simplifying the problem and carefully inspecting the boundaries where data moves between the host and devices, or between devices themselves.
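A quick sanity check before applying any of these strategies is to confirm which devices JAX can actually see; a pmapped axis size must match this count:

```python
import jax

# Confirm what JAX sees before debugging distribution logic.
print(jax.devices())             # the available devices
print(jax.local_device_count())  # the count a pmapped axis must match
```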
**Simplify: Remove `pmap` First:** Before debugging the parallel version, ensure the underlying function works correctly on a single device with a single shard of data. Run the function directly, without `pmap` (and potentially without `jit` initially), on this single slice.

**Reduce Device Count:** If the single-slice version works, try `pmap` with the smallest possible number of devices (even just 1 or 2, if your hardware allows and the logic isn't dependent on a specific count). If the error appears only with multiple devices, it's more likely related to data distribution, gathering, or collectives.
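The single-slice check can be sketched as follows; `step` here is a hypothetical stand-in for your own per-device computation:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device computation we want to debug.
def step(x):
    return jnp.tanh(x).sum()

# Full batch: 8 shards of shape (5,)
data = jnp.arange(8 * 5, dtype=jnp.float32).reshape((8, 5))

# 1. Run the plain function on a single shard, with no pmap or jit.
single = step(data[0])

# 2. If that works, try the jitted version on the same shard.
jitted = jax.jit(step)(data[0])

# Both paths should agree before moving on to pmap.
assert jnp.allclose(single, jitted)
```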
**Check Shapes Diligently:** Use `.shape` extensively outside the `pmap`ped function.

- Print the shapes of all inputs just before calling `pmap`. Does the dimension specified in `in_axes` match `jax.local_device_count()`?
- Print the shape of the output just after `pmap`. Does it match your expectation based on the per-device output shape and the `out_axes` specification?

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax  # lax is needed for lax.axis_index below

# Example: Assume we have 4 devices
num_devices = 4  # Replace with jax.local_device_count() in practice

# Data intended to be split across 4 devices
sharded_data = jnp.arange(4 * 10).reshape((num_devices, 10))
# Data intended to be replicated on all devices
replicated_data = jnp.array(5.0)

def my_func(x, y):
    # x is sharded (per-device shape is (10,)), y is replicated (scalar)
    return x * y + lax.axis_index('batch')  # Use device ID

pmapped_func = pmap(my_func, in_axes=(0, None), out_axes=0, axis_name='batch')

print("Shape of sharded_data:", sharded_data.shape)
# Expected: (4, 10) - check that the first dim matches num_devices
assert sharded_data.shape[0] == num_devices

print("Shape of replicated_data:", replicated_data.shape)
# Expected: () for a scalar; replicated arguments are not split along axis 0

output = pmapped_func(sharded_data, replicated_data)
print("Shape of output:", output.shape)
# Expected: (4, 10) - out_axes=0 stacks the (10,) results from each device
```
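Conversely, it can help to see what the distribution-phase failure from a mis-sized leading axis looks like. This sketch deliberately provokes it by making axis 0 one larger than the device count:

```python
import jax
import jax.numpy as jnp
from jax import pmap

n = jax.local_device_count()

# Leading axis deliberately one larger than the number of devices.
bad = jnp.zeros((n + 1, 3))

try:
    pmap(lambda x: x * 2)(bad)
except ValueError as err:
    # pmap rejects the input: axis 0 has size n + 1, but only n devices exist
    print("pmap raised:", type(err).__name__)
```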
**Use `jax.debug` for Inspecting Values Inside:** Standard Python `print()` inside `pmap` executes on each device independently and often prints asynchronously to the host, making output interleaved and hard to follow. It can also interfere with execution. Use JAX's debugging utilities instead:

- `jax.debug.print`: Prints values from compiled/pmapped functions, letting you label output with the device source. It handles synchronization better than `print()`. You can use it conditionally based on the device ID (`lax.axis_index`) to reduce noise.

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax

def func_with_debug_print(x):
    intermediate = x * 2
    idx = lax.axis_index('data_axis')
    # Print the intermediate value only from device 0. A plain Python `if`
    # on the traced device index would fail under compilation, so gate the
    # print with lax.cond instead.
    lax.cond(
        idx == 0,
        lambda: jax.debug.print("Device 0 intermediate: {val}", val=intermediate),
        lambda: None,
    )
    # Collective operation: sum `intermediate` across all devices
    result = lax.psum(intermediate, axis_name='data_axis')
    jax.debug.print("Device {id} result after psum: {res}", id=idx, res=result)
    return result

# Assume 2 devices
data = jnp.arange(2 * 3).reshape((2, 3))
pmapped_func = pmap(func_with_debug_print, axis_name='data_axis')

# On execution, you'll see labeled prints
output = pmapped_func(data)
# Output might show:
# Device 0 intermediate: [0 2 4]
# Device 0 result after psum: [6 10 14]
# Device 1 result after psum: [6 10 14]
```
- `jax.debug.breakpoint()`: Pauses execution on all devices when hit, allowing inspection via `pdb`. Use sparingly, as it halts everything.

**Test Collective Operations Carefully:** If you suspect a collective is causing issues:

- Isolate it in a minimal `pmap`ped function that only performs that collective with simple input data.
- Use `jax.debug.print` to see the input to the collective on each device and the output from it.
- Double-check the `axis_name`.

**Consider Temporarily Disabling JIT:** While `pmap` requires compilation, if you face obscure compilation errors specifically when using `pmap` that don't occur with `jit` alone, you might try `jax.disable_jit()` in a limited scope around your `pmap` call during debugging. This forces a different execution path that might yield a more understandable Python error, but be aware that this is not how `pmap` normally runs and might mask the real issue or introduce different behavior. Use this as a last resort for diagnosis, not as a solution.
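As a sketch of that diagnostic workflow, assuming a simple placeholder function `normalize` in place of any real failing code:

```python
import jax
import jax.numpy as jnp
from jax import pmap

# Placeholder for a function that triggers an obscure compile error.
def normalize(x):
    return x / x.sum()

data = jnp.ones((jax.local_device_count(), 4))

# Diagnostic only: with jit disabled, the pmapped call runs eagerly, so a
# failure surfaces as an ordinary Python traceback instead of an XLA error.
with jax.disable_jit():
    out = pmap(normalize)(data)

print(out.shape)  # leading axis is still the device count
```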
Debugging `pmap` requires patience and a systematic approach. By simplifying the problem, checking data shapes carefully, and using appropriate debugging tools like `jax.debug.print`, you can effectively isolate and resolve issues related to multi-device parallel execution.
© 2025 ApX Machine Learning