现代机器学习的规模经常超出单个计算设备的能力。无论是处理庞大的数据集还是包含数十亿参数的模型,分散计算负荷都变得必不可少。并行处理的基本原理用于应对这些大规模问题,为理解 JAX 的分布式计算功能,尤其是 pmap,提供依据。并行处理,本质上涉及将计算任务划分为更小、独立或半独立的部分,这些部分可以在多个处理单元上同时执行。在机器学习训练和推断的背景下,已经出现了几种不同的策略,这主要取决于瓶颈是数据量还是模型本身的大小。数据并行这可以说是加速模型训练最常见的并行化策略。核心思想很简单:如果你有一个大型数据集,将其分割成更小的块,并并行处理每个块。模型复制: 相同的模型架构及其当前参数被复制到多个计算设备(例如,GPU或TPU核心)上。数据分片: 全局训练数据批次被划分为小批次,每个设备接收一个独有的小批次。并行正向/反向传播: 每个设备独立地使用其模型副本在其本地小批次上执行正向和反向传播计算,从而计算本地梯度。梯度聚合: 每个设备上计算的梯度会在所有设备间进行聚合。常见的聚合方法包括平均(pmean)或求和(psum)。这个聚合梯度表示在整个全局批次上计算出的梯度。参数更新: 所有设备上的模型参数都会使用聚合梯度进行更新。这确保了所有模型副本保持同步。这种方法有效地增加了每步可处理的总批次大小,通常能带来更快的收敛和更好地利用多个加速器。当模型能够轻松放入单个设备的内存中,但数据集非常大时,它最有效。JAX 的 pmap 函数主要用于实现这种单程序多数据(SPMD)模式的数据并行。digraph G { rankdir=TB; node [shape=box, style=filled, fillcolor="#a5d8ff", fontname="sans-serif"]; edge [fontname="sans-serif"]; subgraph cluster_devices { label = "设备 (GPUs/TPUs)"; style=filled; color="#e9ecef"; node [shape=record, fillcolor="#ffec99"]; Device1 [label="{设备 1 | 模型副本 | 数据分片 1 | 本地梯度 1}"]; Device2 [label="{设备 2 | 模型副本 | 数据分片 2 | 本地梯度 2}"]; DeviceN [label="{设备 N | 模型副本 | 数据分片 N | 本地梯度 N}"]; } Data [label="全局数据批次", shape=cylinder, fillcolor="#96f2d7"]; Aggregator [label="梯度聚合\n(例如,pmean)", shape=circle, fillcolor="#ffc9c9"]; Update [label="参数更新", shape=diamond, fillcolor="#bac8ff"]; Data -> {Device1, Device2, DeviceN} [label="分片数据"]; {Device1, Device2, DeviceN} -> Aggregator [label="发送梯度"]; Aggregator -> Update [label="聚合梯度"]; Update -> {Device1, Device2, DeviceN} [label="更新参数", style=dashed]; }数据并行的流程。数据被分片,在模型副本上并行处理,梯度被聚合,参数同步更新。模型并行当模型变得非常大,以至于其参数、激活或中间状态无法放入单个加速器的内存中时,仅靠数据并行是不够的。模型并行通过将模型本身分割到多个设备上解决此问题。模型分割: 模型的不同部分(例如,层或层内的特定操作)被分配到不同的设备上。数据流: 输入数据按顺序流经模型的各个部分。一个设备上的计算输出可能成为位于另一个设备上的模型下一个部分的输入。设备间通信: 设备之间通常需要大量的通信来在前向传播期间传递激活,在反向传播期间传递梯度。这种策略允许训练极大的模型,但带来了有效分割模型和管理设备间通信开销的复杂性。常见方法包括张量并行(分割单个权重矩阵)和流水线并行。digraph G { rankdir=LR; node [shape=box, style=filled, fillcolor="#a5d8ff", fontname="sans-serif"]; edge [fontname="sans-serif"]; subgraph cluster_model { label = "模型跨设备分割"; style=filled; color="#e9ecef"; node [shape=box, fillcolor="#ffec99"]; Device1 [label="设备 1\n(模型部分 1)"]; Device2 [label="设备 2\n(模型部分 2)"]; DeviceN [label="设备 N\n(模型部分 N)"]; } Input [label="输入数据", shape=ellipse, fillcolor="#96f2d7"]; Output [label="输出", shape=ellipse, fillcolor="#96f2d7"]; Input -> Device1; Device1 -> Device2 [label="激活/梯度"]; Device2 -> DeviceN [label="激活/梯度"]; DeviceN -> Output; }模型并行的流程。单个数据样本流经分布在多个设备上的模型部分。流水线并行流水线并行是一种更精巧的模型并行形式,旨在提高设备利用率。与其让单个批次顺序流经分布在不同设备上的模型部分(导致一些设备空闲而另一些设备工作),流水线并行将批次划分为更小的微批次。分阶段: 模型被划分为顺序的阶段,每个阶段分配给一个设备或一组设备。微批处理: 输入数据批次被分割成更小的微批次。流水线执行: 设备以交错的方式处理微批次。当设备 1 完成处理阶段 1 的微批次 1 并开始处理微批次 2 时,设备 2 就可以开始处理阶段 2 的微批次 1。这产生了流水线效应,使更多设备同时保持繁忙。这种方法有助于减少天真的模型并行中固有的空闲时间的“泡沫”,但需要仔细管理依赖关系、调度和状态(例如反向传播所需的激活)。digraph G { rankdir=TD; node [shape=record, style=filled, fontname="sans-serif"]; edge [fontname="sans-serif"]; subgraph cluster_dev1 { label="设备 1 (阶段 1)"; color="#e9ecef"; fillcolor="#e9ecef"; style=filled; D1_MB1 [label="MB1", fillcolor="#ffec99"]; D1_MB2 [label="MB2", fillcolor="#ffe066"]; D1_MB3 [label="MB3", fillcolor="#ffd43b"]; } subgraph cluster_dev2 { label="设备 2 (阶段 2)"; color="#e9ecef"; fillcolor="#e9ecef"; style=filled; D2_MB1 [label="MB1", fillcolor="#ffec99"]; D2_MB2 [label="MB2", fillcolor="#ffe066"]; D3_MB3 [label="MB3", fillcolor="#ffd43b"];} subgraph cluster_dev3 { label="设备 3 (阶段 3)"; color="#e9ecef"; fillcolor="#e9ecef"; style=filled; D3_MB1 [label="MB1", fillcolor="#ffec99"]; D3_MB2 [label="MB2", fillcolor="#ffe066"]; D3_MB3 [label="MB3", fillcolor="#ffd43b"]; } // Time steps invisible nodes node [shape=point, width=0]; T0 -> T1 -> T2 -> T3 -> T4 -> T5 [style=invis]; { rank=same; T0; } { rank=same; T1; D1_MB1; } { rank=same; T2; D1_MB2; D2_MB1;} { rank=same; T3; D1_MB3; D2_MB2; D3_MB1; } { rank=same; T4; D2_MB3; D3_MB2; } { rank=same; T5; D3_MB3; } // Arrows indicating data flow D1_MB1 -> D2_MB1; D1_MB2 -> D2_MB2; D1_MB3 -> D2_MB3; D2_MB1 -> D3_MB1; D2_MB2 -> D3_MB2; D2_MB3 -> D3_MB3; }流水线并行在三个设备/阶段上随时间步长(T1-T5)变化的示意图。微批次(MB1、MB2、MB3)顺序进入流水线,使设备能够同时处理不同的微批次。重点关注使用 pmap 的数据并行尽管所有这些并行策略在现代机器学习中都很重要,但本章将主要关注数据并行。JAX 为此提供了强大的工具,通过其 pmap 转换,该转换将函数映射到多个设备上,自动处理数据分布并提供集体通信机制(例如梯度聚合)。理解数据并行和 pmap 对于扩展 JAX 中大多数标准训练工作负载是基础的。我们将在接下来的章节中查看设备管理、SPMD 执行模型、集体操作和实际实施细节。理解模型并行和流水线并行提供了重要背景,特别是在考虑第 6 章中讨论的极大模型时。