趋近智
jax.pmap 的实际应用涉及在可用设备上分发一个简单计算。pmap 遵循单程序多数据 (SPMD) 原则:相同的 Python 函数代码在所有参与设备上运行,但每个设备获取输入数据的各自部分。
对于这些示例,JAX 将自动使用所有可用设备(如果未找到 GPU/TPU,则使用 CPU 核心;或使用所有可用 GPU/TPU)。您可以使用 jax.device_count() 查看 JAX 识别的设备数量。
import jax
import jax.numpy as jnp
import numpy as np
# 检查可用设备
num_devices = jax.device_count()
print(f"Number of devices available: {num_devices}")
# 获取设备列表
devices = jax.devices()
print(f"Available devices: {devices}")
如果您在只有 CPU 的机器上运行此代码,jax.device_count() 可能会返回 1 或 JAX 配置为视为独立设备的 CPU 核心数量(通常默认是 1,除非另行配置)。如果您有多个 GPU 或在 TPU Pod 切片上,它将报告可用加速器的数量。当 num_devices > 1 时,pmap 展现其真正的作用。即使只有一个设备,代码也会运行,但不会进行跨设备并行执行。
让我们定义一个我们想并行运行的简单函数:将数组乘以 2。
# 一个逐元素应用的简单函数
def scale_by_two(x):
print("Compiling and running scale_by_two...") # 看看 JAX 何时进行追踪/运行
return x * 2
现在,让我们创建一些数据。对于 pmap,输入数据需要跨设备分片(分割)。最简单的方法是确保输入数组有一个前导轴,其大小等于设备数量。沿此轴的每个切片将发送到相应的设备。
# 创建可跨设备分割的数据
# 让我们创建一个前导维度等于设备数量的数组
data_size_per_device = 4
total_data_size = num_devices * data_size_per_device
global_data = jnp.arange(total_data_size)
# 重塑数据,使第一个维度与设备数量匹配
# 每一行将发送到一个设备
sharded_data = global_data.reshape((num_devices, data_size_per_device))
print(f"Global data shape: {global_data.shape}")
print(f"Sharded data shape: {sharded_data.shape}")
print(f"Sharded data:\n{sharded_data}")
接下来,我们使用 jax.pmap 转换我们的函数。默认情况下,pmap 假设输入参数 (parameter)的第一个轴(轴 0)应映射到各个设备。
# 对我们的函数应用 pmap
parallel_scale = jax.pmap(scale_by_two)
# 在分片数据上运行并行函数
result = parallel_scale(sharded_data)
# 让我们查看结果
print(f"\nOutput shape: {result.shape}")
print(f"Output type: {type(result)}")
print(f"Output content:\n{result}")
您应该注意到:
result 与输入 sharded_data 具有相同的形状。它是一个 JAX ShardedDeviceArray(或根据 JAX 版本而定的类似类型),表示数据分布在各个设备上。scale_by_two 在每个设备上应用于相应输入切片的结果。例如,如果 num_devices=2,result 的第一行 [0*2, 1*2, 2*2, 3*2] 是在设备 0 上计算的,第二行 [4*2, 5*2, 6*2, 7*2] 是在设备 1 上计算的。in_axes 指定轴如果我们的函数接受多个参数 (parameter),我们只希望对其中一些参数进行并行处理,或者在不同的轴上进行并行处理,该怎么办?这就是 pmap 的 in_axes 参数的作用。
in_axes 指定每个输入参数的哪个轴应该被映射。它可以是一个整数(轴索引)、None(将参数广播到所有设备),或者一个与参数数量匹配的元组/列表。
让我们修改我们的函数,使其接受两个参数:一个要分片的数组 x 和一个要广播的标量 y。
# 接受两个参数的函数
def scale_and_add(x, y):
print(f"Compiling and running scale_and_add on device {jax.process_index()}...")
return x * 2 + y
# x 的数据保持不变
# sharded_data = jnp.arange(num_devices * data_size_per_device).reshape((num_devices, data_size_per_device))
# 要广播的标量值
scalar_y = jnp.float32(100.0)
# 应用 pmap,指定每个参数如何处理
# x (sharded_data):映射轴 0
# y (scalar_y):广播(将相同值发送到所有设备)
parallel_scale_add = jax.pmap(scale_and_add, in_axes=(0, None))
# 运行并行函数
result_add = parallel_scale_add(sharded_data, scalar_y)
print(f"\nInput x shape: {sharded_data.shape}")
print(f"Input y: {scalar_y}")
print(f"\nOutput shape: {result_add.shape}")
print(f"Output content:\n{result_add}")
这里,in_axes=(0, None) 告诉 pmap:
x),沿轴 0 分割它,并将每个切片发送到一个设备。y),将 整个 值 scalar_y 发送到 每个 设备。结果 result_add 将再次具有形状 (num_devices, data_size_per_device),其中每个元素在其各自设备上计算为 x_slice * 2 + 100.0。
lax.psum 进行集合操作并行计算中一个常见需求是汇集来自不同设备的结果。例如,计算每个设备上独立计算的值的总和或平均值。JAX 在 jax.lax 中为此目的提供了集合原语。这些操作仅在经过 pmap 处理的函数 内部 有效。
让我们计算所有设备上所有已处理元素的总和。
# 包含集合操作(跨设备求和)的函数
def scale_and_sum(x):
print(f"Compiling and running scale_and_sum on device {jax.process_index()}...")
scaled_x = x * 2
# 首先 *在每个设备上* 计算和
local_sum = jnp.sum(scaled_x)
# 现在,将所有设备的本地和求和
global_sum = jax.lax.psum(local_sum, axis_name='devices')
# 注意:现在每个设备都持有相同的“global_sum”
return scaled_x, global_sum
# 我们需要告知 pmap psum 中使用的轴名称
# 名称 'devices' 是任意的,但在 pmap 和 psum 之间必须匹配
parallel_scale_sum = jax.pmap(scale_and_sum, axis_name='devices')
# 运行并行函数
sharded_data = jnp.arange(num_devices * data_size_per_device).reshape((num_devices, data_size_per_device))
result_scaled, result_sum = parallel_scale_sum(sharded_data)
# 手动计算预期全局和以进行验证
expected_sum = jnp.sum(jnp.arange(total_data_size) * 2)
print(f"\nOutput scaled data shape: {result_scaled.shape}")
print(f"Output scaled data:\n{result_scaled}") # 仍然是分片的
print(f"\nOutput sum shape: {result_sum.shape}") # 应该被复制
print(f"Output sum (replicated across devices):\n{result_sum}")
print(f"Expected global sum: {expected_sum}")
# 验证所有设备上的和是否相同
# (为演示目的,从第一个设备访问数据)
print(f"Sum computed by pmap (device 0): {result_sum[0]}")
本示例中的要点:
jax.pmap 中定义了 axis_name='devices'。此名称逻辑上将参与并行计算的设备分组。jax.lax.psum(local_sum, axis_name='devices') 执行求和归约。它取每个设备上独立计算的 local_sum,并将这些值在名称为 'devices' 的所有设备组中求和。psum 返回的 global_sum 在所有设备上都被 复制。请注意 result_sum 的形状是 (num_devices,),并且其所有元素都相同,持有真实的合计总和。result_scaled 数组仍然像以前一样是分片的。这些动手实践示例说明了使用 jax.pmap 进行数据并行的核心机制:包括准备分片数据、使用 in_axes 映射函数参数 (parameter)、并行执行函数以及使用 psum 等集合操作进行跨设备通信。这为在多个加速器上扩展 JAX 计算(特别是大型机器学习 (machine learning)模型训练)奠定了基础。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•