趋近智
实用实例展示了如何应用 jax.vmap 进行向量化。这些示例包含多种情境,从简单的函数向量化到将其与JAX的其他变换结合使用。
首先,请确保你已安装并导入JAX:
import jax
import jax.numpy as jnp
import timeit
# 检查可用设备 (CPU/GPU/TPU)
print(f"JAX devices: {jax.devices()}")
我们先从一个旨在处理单个输入(例如,一个标量值)的简单函数开始。
# 定义一个处理单个标量的函数
def process_scalar(x):
return x * 2.0 + 1.0
# 测试单个标量输入
scalar_input = 5.0
scalar_output = process_scalar(scalar_input)
print(f"Input: {scalar_input}, Output: {scalar_output}")
# 现在,我们创建一个批次输入
batch_inputs = jnp.arange(1.0, 6.0) # Array([1., 2., 3., 4., 5.])
print(f"Batch Inputs:\n{batch_inputs}")
如果没有 vmap,你可能会手动遍历批次:
# 手动循环(效率较低)
manual_outputs = []
for item in batch_inputs:
manual_outputs.append(process_scalar(item))
manual_outputs = jnp.array(manual_outputs)
print(f"Manual Loop Outputs:\n{manual_outputs}")
这种做法可行,但Python循环通常速度较慢,尤其是在处理大批次或在加速器上运行复杂函数时。现在,我们来使用 vmap。
# 使用 vmap 创建函数的向量化版本
vectorized_process = jax.vmap(process_scalar)
# 将向量化函数应用于批次
vmap_outputs = vectorized_process(batch_inputs)
print(f"vmap Outputs:\n{vmap_outputs}")
# 验证形状
print(f"Batch Input Shape: {batch_inputs.shape}")
print(f"vmap Output Shape: {vmap_outputs.shape}")
请注意 vmap 如何自动处理批次维度。我们编写 process_scalar 来处理单个值,但 vmap(process_scalar) 可以处理一组值。默认情况下,vmap 假定函数应在输入(的)第一个轴(轴0)上进行映射,并生成输出,其中第一个轴对应于被映射的维度。
in_axes 控制映射轴如果你的函数接受多个参数,并且你只想对其中一些进行向量化怎么办?这时 in_axes 参数就派上用场了。in_axes 是一个元组,指定每个输入参数的哪个轴应该被映射。值为 None 表示相应的参数应进行广播,而不是映射。
我们定义一个对向量进行缩放和平移的函数:
# 处理向量、标量缩放因子和标量平移的函数
def scale_and_shift(vector, scale, shift):
return vector * scale + shift
# 单个输入示例
single_vector = jnp.array([1.0, 2.0, 3.0])
single_scale = 2.0
single_shift = 10.0
# 将函数应用于单个输入
single_output = scale_and_shift(single_vector, single_scale, single_shift)
print(f"Single Output:\n{single_output}")
# 现在,创建批次
batch_vectors = jnp.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]]) # Shape (3, 3)
batch_scales = jnp.array([0.5, 1.0, 1.5]) # Shape (3,)
# 我们希望对批次中的所有项使用*相同*的平移量
fixed_shift = 100.0
我们希望应用 scale_and_shift,使得:
batch_vectors 中的第一个向量与 batch_scales 中的第一个缩放因子一同处理。fixed_shift。我们使用 in_axes=(0, 0, None) 来指定。这告诉 vmap:
batch_vectors) 的轴0。batch_scales) 的轴0。fixed_shift);而是对其进行广播。# 使用 in_axes 进行向量化
# 映射 'vector' 和 'scale' 的轴0,广播 'shift'
vectorized_scale_shift = jax.vmap(scale_and_shift, in_axes=(0, 0, None))
# 应用向量化函数
batch_outputs = vectorized_scale_shift(batch_vectors, batch_scales, fixed_shift)
print(f"\nBatch Vectors Shape: {batch_vectors.shape}")
print(f"Batch Scales Shape: {batch_scales.shape}")
# fixed_shift 是一个标量
print(f"Batch Outputs:\n{batch_outputs}")
print(f"Batch Output Shape: {batch_outputs.shape}")
输出的形状是 (3, 3),其中第一个维度 3 是 vmap 引入的批次维度,第二个维度 3 来自 scale_and_shift 处理的向量的原始形状。
你也可以指定不同的轴。例如,如果你的批次维度是输入数组的最后一个轴,你会使用 in_axes=(..., -1, ...)。
vmap 与 jit 和 grad 结合JAX的一个显著优点是其变换的可组合性。我们来看看 vmap 如何与 jit(用于编译)和 grad(用于求导)协同工作。
考虑一个代表基本计算的简单函数,它可能是机器学习模型更新的一部分:
# 函数:计算元素级的tanh激活
def simple_activation(params, x):
# params 可以是权重或偏差,这里只是一个标量缩放因子
return jnp.tanh(params * x)
# 单个输入示例
single_x = jnp.linspace(-2.0, 2.0, 5) # Shape (5,)
single_param = 1.5
# 计算单个输入的输出
single_output = simple_activation(single_param, single_x)
print(f"Single Input x:\n{single_x}")
print(f"Single Output:\n{single_output}")
# --- 使用 jit ---
# 编译函数
jit_activation = jax.jit(simple_activation)
jit_output = jit_activation(single_param, single_x)
# 确保结果匹配(忽略潜在的浮点精度差异)
assert jnp.allclose(single_output, jit_output)
print("\nJIT 输出与原始输出匹配。")
# --- 使用 grad ---
# 获取相对于第一个参数(params)的梯度函数
grad_activation = jax.grad(simple_activation, argnums=0)
# 计算单个输入的梯度
single_grad = grad_activation(single_param, single_x)
print(f"Gradient w.r.t params (single input):\n{single_grad}")
# --- 将 vmap 与 jit 和 grad 结合 ---
# 现在,为 x 创建批处理输入,保持 params 固定
batch_x = jnp.stack([single_x, single_x * 0.5, single_x * -1.0]) # Shape (3, 5)
print(f"\nBatch Input x Shape: {batch_x.shape}")
# 选项1:先对原始函数进行 vmap,然后 jit
vectorized_activation = jax.vmap(simple_activation, in_axes=(None, 0)) # 广播 param,映射 x
jit_vectorized_activation = jax.jit(vectorized_activation)
batch_output_1 = jit_vectorized_activation(single_param, batch_x)
print(f"Batch Output Shape (vmap -> jit): {batch_output_1.shape}")
# 选项2:先对原始函数进行 jit,然后 vmap
# (这通常是首选,因为 jit 首先看到原始结构)
vectorized_jit_activation = jax.vmap(jit_activation, in_axes=(None, 0))
batch_output_2 = vectorized_jit_activation(single_param, batch_x)
print(f"Batch Output Shape (jit -> vmap): {batch_output_2.shape}")
assert jnp.allclose(batch_output_1, batch_output_2)
print("两种组合顺序的输出都匹配。")
# 现在,我们获取批处理梯度
# 我们希望获取 batch_x 中每个项相对于 'params' 的梯度
# 我们在 grad *之后* 应用 vmap
# grad 返回与 'params' 匹配的梯度结构(此处为标量)
# vmap 添加一个与被映射输入 'x' 对应的批次维度
vectorized_grad = jax.vmap(grad_activation, in_axes=(None, 0)) # 广播 param,映射 x
batch_grads = vectorized_grad(single_param, batch_x)
print(f"\nBatch Gradients w.r.t params:\n{batch_grads}")
print(f"Batch Gradients Shape: {batch_grads.shape}")
# 我们也可以对向量化梯度函数进行 JIT,以提高性能
jit_vectorized_grad = jax.jit(vectorized_grad)
start_time = timeit.default_timer()
batch_grads_jit = jit_vectorized_grad(single_param, batch_x).block_until_ready()
duration = timeit.default_timer() - start_time
print(f"JITed 批次梯度计算时间: {duration:.6f}s")
assert jnp.allclose(batch_grads, batch_grads_jit)
本实例展示了 vmap 的出色集成能力。你可以对函数进行 vmap,然后对结果进行 jit,或者对 grad 后的函数进行 vmap。这种组合方式允许你编写清晰的单实例逻辑,然后使用 vmap 高效地将其应用于批次,使用 grad 进行求导,并使用 jit 进行编译。
尝试应用 vmap 解决以下问题:
给定两组二维点:
points_a = jnp.array([[0, 0], [1, 1]])
points_b = jnp.array([[2, 2], [3, 3], [4, 4]])
编写一个函数 pairwise_distance(a, b),计算单个点 a 和单个点 b 之间的欧几里得距离。然后,使用 vmap(可能嵌套)计算一个矩阵,其中元素 (i, j) 是 points_a[i] 和 points_b[j] 之间的距离。
预期输出形状应为 (2, 3)。
提示: 你可能需要一个 vmap 来处理 points_a 的遍历,以及另一个(嵌套的)vmap 来处理对 a 中每个点遍历 points_b。请仔细考虑每一层级的 in_axes 参数。
# --- 解决方案草图 ---
def euclidean_distance(p1, p2):
# 计算两点之间的距离(形状为 (2,))
return jnp.sqrt(jnp.sum((p1 - p2)**2))
points_a = jnp.array([[0., 0.], [1., 1.]]) # Shape (2, 2)
points_b = jnp.array([[2., 2.], [3., 3.], [4., 4.]]) # Shape (3, 2)
# 目标:计算一个 (2, 3) 距离矩阵 dist(points_a[i], points_b[j])
# 提示1:首先对固定点 a 的 points_b 进行向量化
# vmap_over_b = jax.vmap(euclidean_distance, in_axes=(None, 0))
# 尝试调用 vmap_over_b(points_a[0], points_b) -> 预期形状 (3,)
# 提示2:现在对前一步骤中的 points_a 进行向量化
# vmap_over_a_and_b = jax.vmap(vmap_over_b, in_axes=(0, None))
# 尝试调用 vmap_over_a_and_b(points_a, points_b) -> 预期形状 (2, 3)
# --- 完整解决方案 ---
# 使用嵌套 vmap 计算点对距离
@jax.jit # 对最终计算进行 JIT 以提高效率
def compute_pairwise_distances(arr_a, arr_b):
# 内层循环的 vmap(对固定 a 遍历 b)
vmap_dist_b = jax.vmap(euclidean_distance, in_axes=(None, 0))
# 外层循环的 vmap(遍历 a,对每个元素应用 vmap_dist_b)
vmap_dist_a_b = jax.vmap(vmap_dist_b, in_axes=(0, None))
return vmap_dist_a_b(arr_a, arr_b)
distance_matrix = compute_pairwise_distances(points_a, points_b)
print("\n--- 点对距离问题 ---")
print(f"点 A:\n{points_a}")
print(f"点 B:\n{points_b}")
print(f"点对距离矩阵:\n{distance_matrix}")
print(f"输出形状: {distance_matrix.shape}")
# 预期输出:
# [[2.828427 4.2426405 5.656854 ]
# [1.4142135 2.828427 4.2426405]]
这些实际实例演示了 jax.vmap 的灵活性和实用性。通过将为单个数据点编写的函数转换为能在批次上高效运行的函数,vmap 简化了代码,通常能提高性能,并与JAX的其他变换(如 jit 和 grad)很好地集成。随着你更多地使用JAX,尤其是在机器学习场景中,vmap 将成为处理批次数据不可或缺的工具。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造