Modern hardware often includes multiple accelerators such as GPUs or TPUs. While `jax.jit` optimizes code for a single device and `jax.vmap` handles batching efficiently, neither distributes computation across multiple devices on its own. This chapter introduces `jax.pmap` (parallel map), JAX's function transformation for distributing computations by running them in parallel on different devices.
You will learn:
- The data-parallel (SPMD) model that `pmap` uses.
- How to use `jax.pmap` to execute the same function simultaneously on multiple devices, each operating on a different slice of data.
- How to map inputs and outputs to devices with `in_axes` and `out_axes`.
- How to use collective operations (`jax.lax` primitives) within `pmap`-transformed functions.
- How `pmap` interacts with other JAX transformations like `jit` and `grad`.

By the end of this chapter, you will be able to apply `pmap` to implement data parallelism for your JAX programs on multi-device systems.
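As a rough preview of what data parallelism with `pmap` can look like, here is a minimal sketch. It is not taken from the chapter's own sections: the function name `normalize`, the axis name `"devices"`, and the toy data are illustrative assumptions. Each device receives one slice along the leading axis, and `lax.psum` combines a per-device partial sum across all devices.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Number of devices visible to this process (1 on a CPU-only machine).
n_devices = jax.local_device_count()


def normalize(x):
    # Runs once per device, on that device's slice of the data.
    # lax.psum adds the per-device sums across the named axis,
    # so every device ends up with the same global total.
    total = lax.psum(jnp.sum(x), axis_name="devices")
    return x / total


# The leading axis must match the number of devices: one row per device.
# Values are 1, 2, ..., n_devices * 4 (hypothetical toy data).
data = (jnp.arange(n_devices * 4, dtype=jnp.float32) + 1.0).reshape(n_devices, 4)

# axis_name labels the mapped axis so collectives inside can refer to it.
parallel_normalize = jax.pmap(normalize, axis_name="devices")
result = parallel_normalize(data)

print(result.shape)          # (n_devices, 4)
print(float(result.sum()))   # approximately 1.0
```

The sketch sizes its input from `jax.local_device_count()`, so it should also run on a single-device machine, where `pmap` simply maps over a leading axis of length one.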
5.1 Introduction to Data Parallelism (SPMD)
5.2 Introducing `jax.pmap`
5.3 Mapping Data to Devices (`in_axes`, `out_axes`)
5.4 Device Meshes and Axis Names
5.5 Collective Operations (`lax.psum`, `lax.pmean`, etc.)
5.6 Combining `pmap` with other Transformations
5.7 Debugging `pmap`ped Functions
5.8 Hands-on Practical: Parallel Computation
© 2026 ApX Machine Learning