Modern hardware often includes multiple accelerators like GPUs or TPUs. While `jax.jit` optimizes code for a single device and `jax.vmap` handles batching efficiently, they don't inherently distribute computation across multiple devices. This chapter introduces `jax.pmap` (parallel map), JAX's function transformation for distributing computations by running them in parallel on different devices.
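As a first taste of the transformation, here is a minimal sketch of `jax.pmap` applied to an element-wise function. The function `scale_and_shift` and the array `batch` are illustrative names, not part of the chapter material. The sketch assumes only that at least one device is visible, so it also runs on a CPU-only machine, where the device count is 1.

```python
import jax
import jax.numpy as jnp

# Number of devices visible to this process (1 on a typical CPU-only machine).
n_devices = jax.local_device_count()

def scale_and_shift(x):
    # Hypothetical per-device computation; each device sees one row of `batch`.
    return 2.0 * x + 1.0

# pmap splits the leading axis across devices, so its size must not exceed
# the number of available devices. Here we use exactly one row per device.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# The same compiled function runs on every device, each on its own row.
parallel_fn = jax.pmap(scale_and_shift)
result = parallel_fn(batch)

print(result.shape)  # leading axis indexes devices
```

The result keeps the leading device axis, so its shape matches the input: one row of outputs per device.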
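You will learn:
- The Single-Program, Multiple-Data (SPMD) model of data parallelism that `pmap` uses.
- How to use `jax.pmap` to execute the same function simultaneously on multiple devices, each operating on a different slice of data.
- How to control the mapping of data to devices with `in_axes` and `out_axes`.
- How to communicate between devices using collective operations (`jax.lax` primitives) within `pmap`-transformed functions.
- How `pmap` interacts with other JAX transformations like `jit` and `grad`.

By the end of this chapter, you will be able to apply `pmap` to implement data parallelism for your JAX programs on multi-device systems.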
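To preview how these pieces can fit together, the following sketch combines `in_axes`, a collective (`lax.pmean`), and `grad` inside a `pmap`-transformed function. The linear model, the `loss` function, the axis name `"devices"`, and the random data are all assumptions made for illustration; the sections below cover each ingredient properly.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def loss(w, x, y):
    # Hypothetical mean-squared-error loss for a linear model.
    return jnp.mean((x @ w - y) ** 2)

def sharded_grad(w, x, y):
    # Each device computes the gradient on its own shard of the data...
    g = jax.grad(loss)(w, x, y)
    # ...then the gradients are averaged across all devices on the named axis.
    return lax.pmean(g, axis_name="devices")

# in_axes=(None, 0, 0): broadcast `w` to every device, split `x` and `y`
# along their leading axis. The default out_axes=0 stacks per-device outputs.
p_grad = jax.pmap(sharded_grad, axis_name="devices", in_axes=(None, 0, 0))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
w = jnp.ones((3,))
x = jax.random.normal(key_x, (n_devices, 8, 3))  # one shard of 8 examples per device
y = jax.random.normal(key_y, (n_devices, 8))

grads = p_grad(w, x, y)
print(grads.shape)  # (n_devices, 3); every row holds the same averaged gradient
```

Because every shard has the same number of examples, the averaged per-device gradient equals the gradient computed on the full dataset, which is the property data-parallel training relies on.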
5.1 Introduction to Data Parallelism (SPMD)
5.2 Introducing `jax.pmap`
5.3 Mapping Data to Devices (`in_axes`, `out_axes`)
5.4 Device Meshes and Axis Names
5.5 Collective Operations (`lax.psum`, `lax.pmean`, etc.)
5.6 Combining `pmap` with other Transformations
5.7 Debugging `pmap`ped Functions
5.8 Hands-on Practical: Parallel Computation
© 2025 ApX Machine Learning