Debugging pmapped Functions

Debugging code that runs in parallel across multiple devices introduces complexities beyond those found in single-device execution or standard Python. When you wrap a function with jax.pmap, you're moving from a single execution stream to multiple streams operating simultaneously under the Single Program, Multiple Data (SPMD) model. Errors might occur not just within the logic of your function, but also in how data is distributed, processed across devices, and gathered. Furthermore, since pmap implicitly compiles your function using XLA (similar to jit), you also need to consider the challenges associated with debugging compiled code.
Here are common issues encountered with pmap and strategies to address them:
Shape Mismatches: This is perhaps the most frequent source of errors.
- Input sharding (in_axes): pmap splits input arrays along the axis specified in in_axes. The size of this axis must equal the number of devices JAX is using for the pmap execution (usually jax.local_device_count(), unless you are using a multi-host setup). If you have 4 GPUs and map over an array with in_axes=0, the first dimension of that array must be 4. A mismatch causes an error during the distribution phase.
- Output gathering (out_axes): If the function on each device produces an output shape that is inconsistent with how pmap expects to gather it via out_axes, errors can occur. Ensure the shape produced per device makes sense when stacked.
- Replicated vs. sharded arguments: Arguments marked None in in_axes are replicated, not split. Ensure you correctly distinguish between data that should be divided among devices (sharded) and data that should be identical on all devices (replicated). Providing data intended for sharding with in_axes=None (or vice versa) leads to incorrect shapes inside the per-device computation. A short sketch of the sizing rule follows this list.
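As a minimal sketch of that sizing rule (assuming four local devices; in practice, use jax.local_device_count() instead of a hard-coded count):

```python
import jax.numpy as jnp
from jax import pmap

num_devices = 4  # assumed device count; use jax.local_device_count() in practice

good = jnp.zeros((num_devices, 8))      # leading axis matches the device count
bad = jnp.zeros((num_devices + 1, 8))   # leading axis does not match

double = pmap(lambda x: x * 2, in_axes=0)
out = double(good)   # works: each device receives one (8,) slice
# double(bad)        # fails during distribution: axis 0 size != number of devices
```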
Collective Operation Issues: Functions like lax.psum, lax.pmean, and lax.all_gather coordinate operations across all devices involved in the pmap.
- Mismatched axis_name: The axis_name used in the collective function (e.g., lax.psum(x, axis_name='my_devices')) must exactly match the name provided in the pmap call (e.g., pmap(..., axis_name='my_devices')). Typographical errors are common; a short sketch of a correctly matched name follows this list.
- Divergent control flow: Hangs or deadlocks can occur when control flow inside the pmapped function causes some devices to skip the collective call. Remember, under SPMD, all devices execute the same code; conditional logic should typically depend only on replicated values, or on the device ID (lax.axis_index) if divergence is intended and carefully managed.
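As a small illustration (assuming two local devices; the axis name 'devices' and the normalize function are placeholders), the name given to the collective and to pmap must be identical:

```python
import jax.numpy as jnp
from jax import pmap, lax

def normalize(x):
    # 'devices' here must match the axis_name passed to pmap below
    return x / lax.psum(x, axis_name='devices')

data = jnp.arange(2 * 3.0).reshape((2, 3))
out = pmap(normalize, axis_name='devices')(data)
# Using a different name in either place (e.g. 'device') raises an
# unbound axis name error at trace time.
```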
Tracing and Compilation Errors: Since pmap compiles the function, you might encounter errors similar to those with jit.
- Data-dependent Python control flow: Python control flow constructs that depend on traced values (such as a plain Python if) are problematic. Refer back to the debugging techniques for jit, such as using jax.lax.cond or ensuring conditionals operate on static arguments; a brief lax.cond sketch follows this list.
- Side effects: Side effects inside the pmapped function (like modifying external variables or printing with standard Python print) behave unpredictably or cause errors.
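For instance, here is a minimal sketch of replacing a data-dependent Python if with jax.lax.cond inside a pmapped function (the function name, threshold, and shapes are illustrative):

```python
import jax.numpy as jnp
from jax import pmap, lax

def clip_step(x):
    # A Python `if x.sum() > 0:` would fail here, because x.sum() is a traced value.
    # lax.cond builds both branches into the compiled computation instead.
    return lax.cond(x.sum() > 0, lambda v: v * 0.5, lambda v: v, x)

data = jnp.arange(2 * 4.0).reshape((2, 4)) - 3.0  # one row per device (2 devices assumed)
out = pmap(clip_step)(data)
```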
Debugging pmap often involves simplifying the problem and carefully inspecting the boundaries where data moves between the host and devices, or between devices themselves.

Simplify: Remove pmap First: Before debugging the parallel version, ensure the underlying function works correctly on a single device with a single shard of data.
- Test a Single Slice: Take one shard of your input (e.g., data[0]) and call the function directly without pmap (and potentially without jit initially) on this single slice; a sketch of this check follows this list.
- Reduce Device Count: If the single-slice version works, try pmap with the smallest possible number of devices (even just 1 or 2, if your hardware allows and the logic isn't dependent on a specific count). If the error appears only with multiple devices, it's more likely related to data distribution, gathering, or collectives.
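A sketch of that single-slice check, using a stand-in per-device function (my_step, the shapes, and the scale argument are illustrative):

```python
import jax.numpy as jnp

def my_step(x_shard, scale):
    # Stand-in for the per-device computation you intend to pmap
    return (x_shard * scale).sum()

full_data = jnp.arange(4 * 10.0).reshape((4, 10))  # four devices' worth of data
single_shard = full_data[0]                        # what one device would see

# Validate the core logic on one shard, with no pmap (and no jit) involved
print(my_step(single_shard, 2.0))
```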
Check Shapes Diligently: Use .shape extensively outside the pmapped function.
- Check input shapes just before calling pmap. Does the dimension specified in in_axes match jax.local_device_count()?
- Check the output shape returned by pmap. Does it match your expectation based on the per-device output shape and the out_axes specification?

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax

# Example: assume we have 4 devices
num_devices = 4  # Replace with jax.local_device_count() in practice

# Data intended to be split across 4 devices
sharded_data = jnp.arange(4 * 10).reshape((num_devices, 10))

# Data intended to be replicated on all devices
replicated_data = jnp.array(5.0)

def my_func(x, y):
    # x is sharded (per-device shape is (10,)), y is replicated (scalar)
    return x * y + lax.axis_index('batch')  # Use the device ID

pmapped_func = pmap(my_func, in_axes=(0, None), out_axes=0, axis_name='batch')

print("Shape of sharded_data:", sharded_data.shape)
# Expected: (4, 10) - check that the first dim matches num_devices
assert sharded_data.shape[0] == num_devices

print("Shape of replicated_data:", replicated_data.shape)
# Expected: () for a scalar; replicated data is not split, so its shape
# need not match num_devices along axis 0

output = pmapped_func(sharded_data, replicated_data)
print("Shape of output:", output.shape)
# Expected: (4, 10) - out_axes=0 stacks the (10,) results from each device
```
Use jax.debug for Inspecting Values Inside: Standard Python print() inside pmap executes on each device independently and often prints asynchronously to the host, making output interleaved and hard to follow. It can also interfere with execution. Use JAX's debugging utilities instead:
- jax.debug.print: Prints values from compiled/pmapped functions, labelling output with the device source. It handles synchronization better than print(). You can use it conditionally based on the device ID (lax.axis_index) to reduce noise.

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax

def func_with_debug_print(x):
    intermediate = x * 2
    device_id = lax.axis_index('data_axis')

    # Print the intermediate value only from device 0. A plain Python `if` on
    # the traced device index would fail under pmap, so use lax.cond for the
    # conditional side effect instead.
    lax.cond(
        device_id == 0,
        lambda v: jax.debug.print("Device 0 intermediate: {val}", val=v),
        lambda v: None,
        intermediate,
    )

    # Collective operation
    result = lax.psum(intermediate, axis_name='data_axis')
    jax.debug.print("Device {id} result after psum: {res}", id=device_id, res=result)
    return result

# Assume 2 devices
data = jnp.arange(2 * 3).reshape((2, 3))
pmapped_func = pmap(func_with_debug_print, axis_name='data_axis')

# On execution, you'll see labeled prints
output = pmapped_func(data)
# Output might show (ordering can vary):
# Device 0 intermediate: [0 2 4]
# Device 0 result after psum: [ 6 10 14]
# Device 1 result after psum: [ 6 10 14]
```
- jax.debug.breakpoint(): Pauses execution on all devices when hit, allowing inspection via pdb. Use sparingly, as it halts everything; a minimal sketch follows.
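A minimal sketch of dropping a breakpoint into a pmapped function (assuming two local devices; the function and data are placeholders):

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax

def suspicious_step(x):
    y = x * 3
    # Execution pauses here so the traced values can be inspected
    jax.debug.breakpoint()
    return lax.pmean(y, axis_name='i')

data = jnp.arange(2 * 4.0).reshape((2, 4))
out = pmap(suspicious_step, axis_name='i')(data)
```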
Test Collective Operations Carefully: If you suspect a collective is causing issues:
- Isolate it: write a minimal pmapped function that only performs that collective with simple input data (a sketch appears below).
- Use jax.debug.print to see the input to the collective on each device and the output from it.
- Double-check the axis_name.

Consider Temporarily Disabling JIT: While pmap requires compilation, if you face obscure compilation errors specifically when using pmap that don't occur with jit alone, you might try jax.disable_jit() in a limited scope around your pmap call during debugging. This forces a different execution path that might yield a more understandable Python error, but be aware that this is not how pmap normally runs and it might mask the real issue or introduce different behavior. Use this as a last resort for diagnosis, not as a solution.
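Returning to the isolation idea above, a minimal collective-only test might look like the following (the function name, axis name, and two-device assumption are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import pmap, lax

def psum_only(x):
    # Show what each device contributes before the collective
    jax.debug.print("Device {i} input: {x}", i=lax.axis_index('devices'), x=x)
    total = lax.psum(x, axis_name='devices')
    jax.debug.print("Device {i} psum result: {t}", i=lax.axis_index('devices'), t=total)
    return total

# Simple, easy-to-verify input: one row per device
data = jnp.arange(2 * 3.0).reshape((2, 3))
out = pmap(psum_only, axis_name='devices')(data)
print(out)  # every row should equal the elementwise sum of all per-device rows
```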
Debugging pmap requires patience and a systematic approach. By simplifying the problem, checking data shapes carefully, and using appropriate debugging tools like jax.debug.print, you can effectively isolate and resolve issues related to multi-device parallel execution.