pmap在JAX中广泛应用于在连接到一台机器或主机上的多个加速器(GPU或TPU)之间进行并行计算。这种方法对许多任务非常有效,能够通过数据并行实现显著加速。然而,最严峻的机器学习挑战,特别是训练拥有数十亿或数万亿参数的基础模型,需要的计算资源远远超出了单台主机所能提供的范围,即使是配备了多个加速器的主机也无法满足。这促使我们转向多主机编程。JAX中的多主机编程将我们之前讨论过的并行性思路进行了扩展,以便通过网络协调多台独立机器上的计算。设想一下一个机器集群,每台机器都可能配备多个GPU或TPU。目标是调度这些机器,共同完成一个大型计算任务。扩展到多个节点虽然与单主机pmap类似,多主机执行带来了重要的实际考量:网络通信: 在单个主机内部,设备通常通过PCIe或NVLink等快速互连方式进行通信。主机之间,通信则依赖于网络(例如以太网、InfiniBand或专用TPU互连)。网络延迟和带宽通常比主机内部互连差几个数量级,这使得跨主机通信可能成为性能瓶颈。高效的集合通信算法变得非常重要。同步与协调: 你不再处理由单个操作系统管理的线程。相反,你有运行在不同机器上的独立Python进程,它们需要协调各自的执行。这包括确保进程正确启动、相互发现,并在特定时刻(例如在集合操作期间)进行同步。基础设施与作业启动: 设置和管理多主机环境更为复杂。你需要有方法在所有参与主机上启动相同的JAX脚本、配置网络,并可能处理故障。这通常涉及使用集群管理系统(如SLURM)或云服务商专用工具(例如Google Cloud管理TPU Pod切片的工具)。JAX本身不提供这些基础设施工具;它提供的是在此类环境中运行的编程原语。多主机环境中的JAX原语好消息是,JAX的核心原语在设计时就考虑到了多主机场景,特别是针对SPMD(单程序,多数据)工作负载。全局设备感知: 当JAX程序在一个正确配置的多主机环境(例如TPU Pod)中运行时,jax.devices()将返回跨越所有主机的所有参与设备的列表。JAX会自动识别设备布局。跨主机pmap: pmap函数作用于这个全局设备列表。为单个主机编写的pmap代码通常可以在多主机环境中运行。你可以将函数映射到所有可用设备上,无论它们位于哪个主机。数据分片需要考虑到全局设备的总数。跨主机集合操作: jax.lax.psum、pmean、all_gather等集合通信原语是设备间信息交换的主要方式。在多主机设置中,这些操作透明地处理所需的网络通信。在TPU Pod等平台上,这些集合操作经过高度优化,以便使用TPU芯片之间专用的高速互连,从而减少通信开销。进程识别: JAX提供了jax.process_index和jax.process_count。jax.process_count会告知有多少个不同的JAX进程(通常每个主机一个)参与了计算。jax.process_index给出当前进程的唯一序号(通常从0到process_count - 1)。这些可用于,例如,让每个主机从磁盘加载大型数据集的不同分片。digraph G {rankdir=TB;bgcolor="white";node [shape=record,style=filled,fillcolor="#e9ecef",fontname="sans-serif"];subgraph cluster0 {label="主机 0 (process_index=0)";bgcolor="#f8f9fa";node [fillcolor="#a5d8ff"];H0D0 [label="设备 0"];H0D1 [label="设备 1"];H0D2 [label="设备 2"];H0D3 [label="设备 3"];} subgraph cluster1 {label="主机 1 (process_index=1)";bgcolor="#f8f9fa";node [fillcolor="#a5d8ff"];H1D0 [label="设备 4"];H1D1 [label="设备 5"];H1D2 [label="设备 6"];H1D3 [label="设备 7"];} pmap[label="pmap(f, axis_name='batch')",shape=box,fillcolor="#ffec99"];psum[label="psum(..., axis_name='batch')",shape=diamond,fillcolor="#96f2d7"];pmap -> H0D0;pmap -> H0D1;pmap -> H0D2;pmap -> H0D3;pmap -> H1D0;pmap -> H1D1;pmap -> H1D2;pmap -> H1D3;H0D0 -> psum;H0D1 -> psum;H0D2 -> psum;H0D3 -> psum;H1D0 -> psum;H1D1 -> psum;H1D2 -> psum;H1D3 -> psum;psum -> pmap [label="聚合结果",style=dashed,color="#0ca678",fontsize=10];} 一个pmap作用于两个主机(每个主机有四个设备)的视图。pmap将工作分配到所有八个全局设备上。集合操作(psum)聚合所有设备上的结果,这需要通过主机间的网络互连进行通信。SPMD 的扩展单程序多数据模型自然地适用于多主机。每个主机都会启动并运行包含JAX代码的完全相同的Python脚本。在此脚本内部:数据加载可能因主机而异,使用jax.process_index进行区分。pmap将函数f应用于分配给该主机的本地设备,对相应的数据分片进行操作。pmap内部的集合操作处理所有参与集合的主机上所有设备之间必要的数据交换,通过axis_name指定。以数据并行训练为例:每个主机加载全局小批量数据的不同部分。pmap在每个主机的设备上本地计算梯度。jax.lax.pmean会沿着映射轴对梯度求平均值,跨越所有全局设备,在优化器步骤执行之前(优化器步骤也可能通过pmap执行或在所有主机上相同运行)。展望本节为多主机编程提供了基本内容。实际实施和运行多主机JAX作业涉及具体的设置步骤,这些步骤取决于你的硬件(例如TPU Pods、多节点GPU集群)以及基础设施(云平台、SLURM)。尽管针对特定环境的详细实施指南超出了我们目前的范围,但理解这些内容很重要。这展示了JAX的设计如何使用pmap和集合操作等一致的编程抽象,从单设备执行扩展到大型分布式系统。后面会讨论的Flax和Haiku等框架通常在这些原语之上构建,以提供更方便的API,用于在多主机设置中管理分布式训练循环和模型状态。