趋近智
尽管使用 pmap 的数据并行,结合梯度累积和检查点等技术,能够扩展到更大的数据集和有效的批处理大小,但当模型的参数和中间激活变得过大,以至于无法放入单个加速器设备的内存中时,一个显著的限制便会显现。即使使用混合精度,顶尖模型自身的庞大尺寸也可能超出现有最大 GPU 或 TPU 的容量。
当模型的单个副本无法在一台设备上运行时,您需要将模型本身划分到多个设备上。这是模型并行背后的主要思想。与数据并行不同,数据并行是在不同的数据片上运行相同的模型,而模型并行涉及拆分模型的组件(层甚至层的一部分)并将它们分配给不同的加速器。数据在流经划分后的模型时在这些设备之间流动。
实现模型并行有两种主要策略:
张量并行侧重于将单一层内(或操作)的计算拆分到多个设备上。这对于具有巨大权重矩阵的层(例如 Transformers 中常见的大型线性层)尤其有效。
考虑一个大型矩阵乘法,它是神经网络中的一个基本操作:Y=XW。如果权重矩阵 W 对于单个设备来说过大,我们可以将其划分到多个设备上。例如,W 可以按列拆分到两台设备上,W=[W1,W2]。输入 X 被发送到这两台设备。每台设备计算一部分输出:设备 1 上计算 Y1=XW1,设备 2 上计算 Y2=XW2。最终结果 Y 通过连接部分结果得到:Y=[Y1,Y2]。
张量并行应用于矩阵乘法 Y=XW 的简化视图。权重矩阵 W 被拆分为 W1 和 W2,计算 XW1 和 XW2 在不同的设备上进行。输入 X 被复制,部分结果 Y1 和 Y2 被收集。
或者,W 可以按行拆分,这需要通信(例如,一次 all-reduce 操作)来求和部分结果。更复杂的操作,例如注意力机制,也可以通过这种方式并行化。
难点:
在 JAX 中,张量并行通常涉及在分配给并行层的设备子集上使用 pmap,并常与应用于 pmap 中定义的特定 axis_name 维度的 jax.lax.psum 或 jax.lax.all_gather 等集合通信原语结合使用。基于 JAX 构建的库可能会提供常见张量并行模式的抽象。
流水线并行采用不同的方法:它在层之间划分模型,将顺序的阶段或层块分配给不同的设备。一个阶段(在一台设备上)的输出激活成为下一个阶段(在另一台设备上)的输入。
设想一个具有四层(L1, L2, L3, L4)的模型,分布在四台设备(D1, D2, D3, D4)上。
一个简单的实现会顺序处理一个批次:D1 计算 L1,将激活发送到 D2;D2 计算 L2,发送到 D3;D3 计算 L3,发送到 D4;D4 计算 L4。在此过程中,任何时候只有一台设备处于活动状态,导致大量空闲时间,通常称为“流水线气泡”。
简单的流水线并行:层被分配给顺序设备。激活在设备之间传递。
为了提高效率并减少气泡,使用了微批处理。输入批次被拆分为更小的微批次。一旦设备 1 完成处理第一个微批次,它就会将激活发送到设备 2,并立即开始处理第二个微批次。这允许多台设备同时处理不同的微批次,从而填充流水线。
带有微批处理(MB1、MB2 等)的流水线并行。设备同时处理不同的微批次,减少空闲时间(由“-”表示)。初始和最终的气泡仍然存在,但会被许多微批次分摊。
难点:
在 JAX 中,流水线并行通常涉及使用 jax.device_put 或后端特定的放置机制,手动将计算的不同部分放置到特定设备上。阶段之间的通信通常使用点对点传输,这可能会由库抽象或根据设置(单主机与多主机)需要更低级别的原语。管理微批处理和调度通常由用户或更高级别的框架完成。
实践中,复杂的大规模训练设置通常结合这些策略:
直接在 JAX 中仅使用 pmap、lax.p* 集合操作和设备放置等原语实现这些高级模型并行策略是可行的,但要求很高。它需要透彻理解设备拓扑、通信模式和精心编排。因此,实践者通常依赖 JAX 生态系统中更高级别的库(例如 Flax 的扩展或为大型模型构建的专用库),这些库为常见的张量和流水线并行模式提供抽象,使其更易于应用。
尽管本节提供了一个概述,但实际实现高效模型并行通常涉及使用这些专用库,或投入大量工程精力来构建针对特定模型架构和硬件设置的定制解决方案。然而,理解这些策略对于在仅数据并行不足时选择正确的方法和工具是必要的。
这部分内容有帮助吗?
pmap和集体通信原语进行分布式计算和自动并行化的基本机制。© 2026 ApX Machine Learning用心打造