趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数pmap 是 JAX 中用于跨多个设备并行化计算的主要工具。尽管 pmap 通过沿单个轴(由 in_axes 和 out_axes 控制)映射数据来实现基本的数据并行,但对于更复杂的硬件拓扑或高级并行策略来说,这种方法通常是不够的。为了有效管理这些复杂场景,JAX 引入了设备网格和命名轴的概念。
将可用的硬件加速器(CPU、GPU、TPU)不只看作一个扁平列表,而是可能组织成一个逻辑网格或网格。这个网格提供了一种组织和寻址设备的方式。例如,如果你有8个TPU,你可以将它们视为:
[设备 0, 设备 1, ..., 设备 7][[设备 0, 设备 1], [设备 2, 设备 3], [设备 4, 设备 5], [设备 6, 设备 7]] 或 [[设备 0, ..., 设备 3], [设备 4, ..., 设备 7]]JAX 可以报告可用设备:
import jax
# 查看可用设备列表
print(jax.devices())
# 示例输出(可能因硬件而异):
# [CpuDevice(id=0)]
# 或者对于多个GPU:
# [GpuDevice(id=0), GpuDevice(id=1), GpuDevice(id=2), GpuDevice(id=3)]
# 或者对于TPU:
# [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]
当你使用 pmap 时,它隐式地作用于由你所运行设备形成的网格之上。对于沿一个维度分发数据的标准 pmap 调用,它将设备视为一维网格。
当你需要协调特定设备组的操作,或实现更复杂的并行模式(如模型并行与数据并行结合)时,网格的理念变得更加重要。尽管JAX提供了显式网格管理工具(jax.experimental.maps.Mesh),但即使是对于基本的 pmap 使用,理解逻辑设备排列的思想也很重要,尤其是在使用集体操作时。
4个设备作为一维或二维网格的逻辑排列。
想象一下你的设备排列成一个网格。你如何告诉JAX,这个网格的哪个维度与你使用 pmap 分割的数据对应?你又如何协调沿着网格特定维度排列的设备间的操作?这就是轴名称的作用。
定义 pmap 转换的函数时,你可以使用 axis_name 参数为并行发生的设备网格轴分配一个名称。
import jax
import jax.numpy as jnp
# 假设我们有4个设备
def simple_computation(x):
# 某个操作...
return x * 2
# 将计算映射到设备上,并将设备轴命名为 'data_parallel_axis'
# 我们沿着输入数组 'data' 的第一个轴(轴 0)进行分割
# 跨越命名设备轴。
data = jnp.arange(4 * 5).reshape((4, 5)) # 形状 (4, 5) -> 每个设备一行
# 在这里,pmap 隐式创建一个大小为4的一维网格。
# 'data_parallel_axis' 为这个设备网格的单一维度命名。
# in_axes=0 表示 'data' 的第一个轴映射到这个命名设备轴。
parallel_computation = jax.pmap(simple_computation, axis_name='data_parallel_axis', in_axes=0)
result = parallel_computation(data)
print(result.shape) # 输出: (4, 5) - 形状保持不变,但计算是并行运行的
print(jax.devices()) # 显示使用的设备
# 在4个设备上的示例结果:
# [[ 0 2 4 6 8] # <- 在设备0上计算
# [10 12 14 16 18] # <- 在设备1上计算
# [20 22 24 26 28] # <- 在设备2上计算
# [30 32 34 36 38]] # <- 在设备3上计算
为什么要使用 axis_name?
'data_parallel_axis' 明确表示数据并行。在更复杂的情况下,你可能会使用'model_parallel_axis'。'data_parallel_axis',而无论该轴对应的是4个GPU、8个TPU还是不同的硬件设置。pmap中需要在设备之间进行通信的操作(例如汇总所有设备的结果)需要知道与哪个组的设备进行通信。axis_name指定了这个组。例如,lax.psum(x, axis_name='data_parallel_axis') 将会在所有参与由'data_parallel_axis'标识的并行计算的设备上汇总值x。可以将 in_axes 和 out_axes 看作指定数据维度如何映射到设备网格维度的方式,而 axis_name 则为该 pmap 实例使用的设备网格维度提供一个特定名称。这个名称随后在内部使用,尤其是在协调集体通信时,我们将在下一节中讲解。
即使你只有一个并行维度(纯数据并行中的常见情况),命名轴也是一种良好实践,并且一旦在并行执行期间需要设备之间通信,它就变得必不可少。
这部分内容有帮助吗?
jax.experimental.maps.Mesh等高级并行模式。pmap变换的官方API参考,解释了in_axes、out_axes和axis_name参数及其对并行执行的直接影响。© 2026 ApX Machine Learning用心打造