趋近智
pmap在JAX中广泛应用于在连接到一台机器或主机上的多个加速器(GPU或TPU)之间进行并行计算。这种方法对许多任务非常有效,能够通过数据并行实现显著加速。然而,最严峻的机器学习 (machine learning)挑战,特别是训练拥有数十亿或数万亿参数 (parameter)的基础模型,需要的计算资源远远超出了单台主机所能提供的范围,即使是配备了多个加速器的主机也无法满足。这促使我们转向多主机编程。
JAX中的多主机编程将我们之前讨论过的并行性思路进行了扩展,以便通过网络协调多台独立机器上的计算。设想一下一个机器集群,每台机器都可能配备多个GPU或TPU。目标是调度这些机器,共同完成一个大型计算任务。
虽然与单主机pmap类似,多主机执行带来了重要的实际考量:
好消息是,JAX的核心原语在设计时就考虑到了多主机场景,特别是针对SPMD(单程序,多数据)工作负载。
jax.devices()将返回跨越所有主机的所有参与设备的列表。JAX会自动识别设备布局。pmap: pmap函数作用于这个全局设备列表。为单个主机编写的pmap代码通常可以在多主机环境中运行。你可以将函数映射到所有可用设备上,无论它们位于哪个主机。数据分片需要考虑到全局设备的总数。jax.lax.psum、pmean、all_gather等集合通信原语是设备间信息交换的主要方式。在多主机设置中,这些操作透明地处理所需的网络通信。在TPU Pod等平台上,这些集合操作经过高度优化,以便使用TPU芯片之间专用的高速互连,从而减少通信开销。jax.process_index和jax.process_count。jax.process_count会告知有多少个不同的JAX进程(通常每个主机一个)参与了计算。jax.process_index给出当前进程的唯一序号(通常从0到process_count - 1)。这些可用于,例如,让每个主机从磁盘加载大型数据集的不同分片。一个
pmap作用于两个主机(每个主机有四个设备)的视图。pmap将工作分配到所有八个全局设备上。集合操作(psum)聚合所有设备上的结果,这需要通过主机间的网络互连进行通信。
单程序多数据模型自然地适用于多主机。每个主机都会启动并运行包含JAX代码的完全相同的Python脚本。在此脚本内部:
jax.process_index进行区分。pmap将函数f应用于分配给该主机的本地设备,对相应的数据分片进行操作。pmap内部的集合操作处理所有参与集合的主机上所有设备之间必要的数据交换,通过axis_name指定。以数据并行训练为例:每个主机加载全局小批量数据的不同部分。pmap在每个主机的设备上本地计算梯度。jax.lax.pmean会沿着映射轴对梯度求平均值,跨越所有全局设备,在优化器步骤执行之前(优化器步骤也可能通过pmap执行或在所有主机上相同运行)。
本节为多主机编程提供了基本内容。实际实施和运行多主机JAX作业涉及具体的设置步骤,这些步骤取决于你的硬件(例如TPU Pods、多节点GPU集群)以及基础设施(云平台、SLURM)。尽管针对特定环境的详细实施指南超出了我们目前的范围,但理解这些内容很重要。这展示了JAX的设计如何使用pmap和集合操作等一致的编程抽象,从单设备执行扩展到大型分布式系统。后面会讨论的Flax和Haiku等框架通常在这些原语之上构建,以提供更方便的API,用于在多主机设置中管理分布式训练循环和模型状态。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•