pmap in JAX is widely used for computations across multiple accelerators (GPUs or TPUs) attached to a single machine, or host, and it delivers substantial speedups for many tasks through data parallelism. However, the most demanding machine learning workloads, especially training foundation models with billions or trillions of parameters, require computational resources beyond what a single host can offer, even one equipped with several accelerators. This is where multi-host programming becomes necessary.
Multi-host programming in JAX extends the parallelism concepts we've discussed to coordinate computations across multiple independent machines connected via a network. Imagine a cluster of machines, each potentially equipped with several GPUs or TPUs. The goal is to orchestrate these machines to work together on a single, large computational task.
While conceptually similar to single-host pmap, multi-host execution introduces practical considerations of its own, chiefly around launching and coordinating one process per host and communicating between machines over the network.
The good news is that JAX's core primitives are designed with multi-host scenarios in mind, particularly for SPMD (Single-Program, Multiple-Data) workloads.
Several primitives make this work:

- Global device view: jax.devices() returns a list of all participating devices across all hosts; JAX discovers the topology automatically.
- pmap across hosts: The pmap function operates on this global list of devices. The same pmap code you write for a single host can often work in a multi-host setting: you map your function across all available devices, regardless of which host they reside on. Data sharding needs to account for the total, global number of devices.
- Collective operations: jax.lax.psum, pmean, all_gather, and the other collectives are the workhorses for exchanging information between devices. In a multi-host setup, these operations transparently handle the necessary network communication. On platforms like TPU Pods, they are highly optimized to use the dedicated high-speed interconnects between TPU chips, minimizing communication overhead.
- Process identification: jax.process_count tells you how many distinct JAX processes (typically one per host) are participating in the computation, and jax.process_index gives the unique rank (from 0 to process_count - 1) of the current process. These can be used, for example, to have each host load a different shard of a large dataset from disk.

Figure: A view of pmap spanning two hosts, each with four devices. pmap distributes work across all eight global devices, and collective operations (psum) aggregate results across all devices, requiring communication over the network interconnect between hosts.
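To make these pieces concrete, here is a minimal sketch of a script that every host would run identically. The jax.distributed.initialize() call (shown commented out) and the axis name "devices" are assumptions about your environment; the coordinator address, process count, and process id that initialization needs are normally supplied by your launcher or, on Cloud TPU, discovered automatically.

```python
import functools

import jax
import jax.numpy as jnp

# Connect this process to the others before any other JAX call. The required
# arguments (coordinator address, number of processes, process id) depend on
# your cluster; on Cloud TPU they can usually be omitted entirely.
# jax.distributed.initialize()

print(f"process {jax.process_index()} of {jax.process_count()}")
print(f"{jax.local_device_count()} local devices, "
      f"{jax.device_count()} devices globally")

# The mapped function runs once per device on every host. The psum over the
# axis named "devices" sums contributions from all global devices, using the
# inter-host network where necessary.
@functools.partial(jax.pmap, axis_name="devices")
def global_sum(x):
    return jax.lax.psum(x, axis_name="devices")

# Each process feeds only its local devices: the leading axis of the input
# must equal jax.local_device_count(), not the global device count.
out = global_sum(jnp.ones(jax.local_device_count()))
print(out)  # every entry equals the global device count, e.g. 8 for 2 hosts x 4 devices
```

On a single machine the same script still runs, with jax.process_count() equal to 1, which makes it convenient to test locally before launching it across hosts.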
The Single-Program, Multiple-Data model extends naturally to multiple hosts. Each host launches and runs the exact same Python script containing the JAX code. Inside this script:
- Each process determines its own role, such as which shard of the input data to load, from jax.process_index.
- pmap applies the function f to the local devices assigned to that host, operating on the corresponding data shard.
- Collective operations inside pmap handle the necessary data exchange across all devices on all hosts participating in the collective, identified by the axis_name.

Consider data-parallel training: each host loads a distinct portion of the global mini-batch, and pmap computes gradients locally on each host's devices. A jax.lax.pmean over the mapped axis then averages gradients across all devices globally before the optimizer step (which might also be pmap'd or simply run identically on all hosts), as the sketch below illustrates.
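As a rough sketch of that pattern, the snippet below uses a deliberately tiny linear model; the parameter shapes, loss, learning rate, and dummy data are placeholders, and only the pmap / jax.lax.pmean structure is the point.

```python
import functools

import jax
import jax.numpy as jnp

# Placeholder model and loss; only the pmap/pmean pattern matters here.
def loss_fn(params, inputs, targets):
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, inputs, targets):
    grads = jax.grad(loss_fn)(params, inputs, targets)
    # Average gradients across every device on every host so that all
    # replicas apply an identical update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# Replicate parameters onto this host's local devices; pmap expects a leading
# axis of size jax.local_device_count() on every argument.
params = {"w": jnp.zeros((16, 1)), "b": jnp.zeros((1,))}
params = jax.device_put_replicated(params, jax.local_devices())

# In a real pipeline each host would load its own slice of the global
# mini-batch (selected with jax.process_index()) and reshape it to
# (local_device_count, per_device_batch, ...). Dummy data stands in here.
n_local = jax.local_device_count()
inputs = jnp.ones((n_local, 32, 16))
targets = jnp.zeros((n_local, 32, 1))
params = train_step(params, inputs, targets)
```

Because every process applies the same update with identical averaged gradients, the replicated parameters stay in sync across hosts without any explicit broadcast.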
This section provides a foundation for multi-host programming. Actually implementing and running multi-host JAX jobs involves specific setup steps that depend on your hardware (e.g., TPU Pods, multi-node GPU clusters) and infrastructure (cloud platforms, SLURM). While detailed implementation guides for specific environments are outside our current scope, understanding these concepts matters: it shows how JAX's design scales from single-device execution up to large distributed systems using consistent programming abstractions like pmap and collective operations. Frameworks like Flax and Haiku, discussed later, often build on these primitives to offer more convenient APIs for managing distributed training loops and model state in multi-host settings.