趋近智
现代机器学习的规模经常超出单个计算设备的能力。无论是处理庞大的数据集还是包含数十亿参数的模型,分散计算负荷都变得必不可少。并行处理的基本原理用于应对这些大规模问题,为理解 JAX 的分布式计算功能,尤其是 pmap,提供依据。
并行处理,本质上涉及将计算任务划分为更小、独立或半独立的部分,这些部分可以在多个处理单元上同时执行。在机器学习训练和推断的背景下,已经出现了几种不同的策略,这主要取决于瓶颈是数据量还是模型本身的大小。
这可以说是加速模型训练最常见的并行化策略。核心思想很简单:如果你有一个大型数据集,将其分割成更小的块,并并行处理每个块。
pmean)或求和(psum)。这个聚合梯度表示在整个全局批次上计算出的梯度。这种方法有效地增加了每步可处理的总批次大小,通常能带来更快的收敛和更好地利用多个加速器。当模型能够轻松放入单个设备的内存中,但数据集非常大时,它最有效。JAX 的 pmap 函数主要用于实现这种单程序多数据(SPMD)模式的数据并行。
数据并行的流程。数据被分片,在模型副本上并行处理,梯度被聚合,参数同步更新。
当模型变得非常大,以至于其参数、激活或中间状态无法放入单个加速器的内存中时,仅靠数据并行是不够的。模型并行通过将模型本身分割到多个设备上解决此问题。
这种策略允许训练极大的模型,但带来了有效分割模型和管理设备间通信开销的复杂性。常见方法包括张量并行(分割单个权重矩阵)和流水线并行。
模型并行的流程。单个数据样本流经分布在多个设备上的模型部分。
流水线并行是一种更精巧的模型并行形式,旨在提高设备利用率。与其让单个批次顺序流经分布在不同设备上的模型部分(导致一些设备空闲而另一些设备工作),流水线并行将批次划分为更小的微批次。
这种方法有助于减少天真的模型并行中固有的空闲时间的“泡沫”,但需要仔细管理依赖关系、调度和状态(例如反向传播所需的激活)。
流水线并行在三个设备/阶段上随时间步长(T1-T5)变化的示意图。微批次(MB1、MB2、MB3)顺序进入流水线,使设备能够同时处理不同的微批次。
尽管所有这些并行策略在现代机器学习中都很重要,但本章将主要关注数据并行。JAX 为此提供了强大的工具,通过其 pmap 转换,该转换将函数映射到多个设备上,自动处理数据分布并提供集体通信机制(例如梯度聚合)。理解数据并行和 pmap 对于扩展 JAX 中大多数标准训练工作负载是基础的。我们将在接下来的章节中查看设备管理、SPMD 执行模型、集体操作和实际实施细节。理解模型并行和流水线并行提供了重要背景,特别是在考虑第 6 章中讨论的极大模型时。
这部分内容有帮助吗?
pmap、SPMD和集合操作。© 2026 ApX Machine Learning用心打造