So far, our exploration of pmap has focused on leveraging multiple accelerators (GPUs or TPUs) attached to a single machine or host. This is highly effective for many tasks, allowing significant speedups through data parallelism. However, the most demanding machine learning challenges, particularly training foundation models with billions or trillions of parameters, require computational resources that exceed what a single host can provide, even one packed with multiple accelerators. This brings us to the realm of multi-host programming.
Multi-host programming in JAX extends the parallelism concepts we've discussed to coordinate computations across multiple independent machines connected via a network. Imagine a cluster of machines, each potentially equipped with several GPUs or TPUs. The goal is to orchestrate these machines to work together on a single, large computational task.
While conceptually similar to single-host pmap, multi-host execution introduces significant practical considerations: a separate process must be launched and coordinated on every machine, data must be sharded across hosts, and collective communication now crosses a network interconnect rather than staying inside a single machine.
The good news is that JAX's core primitives are designed with multi-host scenarios in mind, particularly for SPMD (Single-Program, Multiple-Data) workloads.
Several pieces of the API support this directly:

- Global device view: jax.devices() returns a list of all participating devices across all hosts; JAX automatically discovers the topology.
- pmap across hosts: pmap operates on this global list of devices. The same pmap code you write for a single host can often work seamlessly in a multi-host setting: you map your function across all available devices, regardless of which host they reside on. Data sharding needs to account for the total global number of devices.
- Collective operations: jax.lax.psum, pmean, all_gather, and friends are the workhorses for exchanging information between devices. In a multi-host setup, these operations transparently handle the necessary network communication. On platforms like TPU Pods, they are highly optimized to use the dedicated high-speed interconnects between TPU chips, minimizing communication overhead.
- Process identification: jax.process_count tells you how many distinct JAX processes (typically one per host) are participating in the computation, and jax.process_index gives the unique rank (usually from 0 to process_count - 1) of the current process. These can be used, for example, to have each host load a different shard of a large dataset from disk. The sketch after the figure below shows these calls in use.

A conceptual view of pmap spanning two hosts, each with four devices: pmap distributes work across all eight global devices, and collective operations (psum) aggregate results across all devices, requiring communication over the network interconnect between hosts.
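As a rough illustration of these APIs, the following sketch queries the process and device topology and then runs a psum across all global devices. It assumes one JAX process per host has already been launched and initialized (as on a TPU Pod slice); the array shape and the axis name "devices" are arbitrary choices made for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Every host runs this exact script (SPMD). These calls describe the
# view from the current process.
print("process", jax.process_index(), "of", jax.process_count())
print("global device count:", jax.device_count())        # e.g. 8 across 2 hosts
print("local device count :", jax.local_device_count())  # e.g. 4 on this host

# pmap is applied to the devices local to this host, but collectives
# named over the mapped axis span all devices on all hosts.
@partial(jax.pmap, axis_name="devices")
def global_sum(x):
    # Sums x across every participating device, including those
    # attached to other hosts.
    return jax.lax.psum(x, axis_name="devices")

# Each host feeds one shard per local device: shape (local_devices, ...).
x = jnp.ones((jax.local_device_count(), 4))
print(global_sum(x))  # every entry equals jax.device_count()
```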
The Single-Program, Multiple-Data model extends naturally to multiple hosts. Each host launches and runs the exact same Python script containing the JAX code. Inside this script:

- Each process determines its role, for instance which shard of the dataset to load, from jax.process_index.
- pmap applies the function f to the local devices assigned to that host, operating on the corresponding data shard.
- Collective operations inside the pmap'd function handle the necessary data exchange across all devices on all hosts participating in the collective, identified by the axis_name.

Consider data-parallel training: each host loads a distinct portion of the global mini-batch, and pmap computes gradients locally on each host's devices. A jax.lax.pmean over the mapped axis then averages gradients across all devices globally before the optimizer step occurs (which might also be pmap'd or run identically on all hosts). A sketch of this pattern follows.
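To make this concrete, here is a minimal sketch of such a training step. The single-layer model, squared-error loss, and LEARNING_RATE constant are stand-ins invented for illustration; only the pmap/pmean structure matters, and the same step runs unchanged whether the devices sit on one host or several.

```python
from functools import partial

import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # hypothetical hyperparameter, chosen for this sketch


def loss_fn(params, batch):
    # Toy model: one dense layer with a squared-error loss.
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


@partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    # Each device computes gradients from its own shard of the batch...
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # ...then gradients (and loss) are averaged over every device on
    # every host via the collective, keeping the update synchronized.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss
```

Each host would call train_step with parameters replicated across its local devices and a batch whose inputs and targets carry a leading jax.local_device_count() axis, typically the slice of the global batch selected using jax.process_index.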
This section provides a conceptual foundation for multi-host programming. Actually implementing and running multi-host JAX jobs involves specific setup steps depending on your hardware (e.g., TPU Pods, multi-node GPU clusters) and infrastructure (cloud platforms, SLURM). While detailed implementation guides for specific environments are outside our current scope, understanding these concepts matters: it shows how JAX's design scales from single-device execution up to large distributed systems using consistent programming abstractions like pmap and collective operations. Frameworks like Flax and Haiku, discussed later, often build upon these primitives to offer more convenient APIs for managing distributed training loops and model states in multi-host settings.
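For orientation only, the snippet below sketches what that setup step typically looks like on a multi-node GPU cluster: each host calls jax.distributed.initialize before any other JAX work so the processes can find each other. The coordinator address, port, and process count shown are placeholder values that would normally come from your launcher (for example, SLURM environment variables); on TPU Pods the call can usually be made with no arguments.

```python
import jax

# Placeholder values for illustration: in practice these usually come
# from the job launcher (e.g., SLURM or a cloud orchestration layer).
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # host 0's reachable address:port
    num_processes=2,                      # total number of hosts/processes
    process_id=0,                         # this host's rank: 0 or 1 here
)

# After initialization, jax.devices() reports devices from every host,
# and the pmap/collective code shown earlier runs unchanged.
print(jax.process_index(), jax.process_count(), jax.device_count())
```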