趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数你希望向量化的函数通常操作多个输入。例如,你可能需要计算两个不同批次中对应向量的点积,或者将批次中的每个向量乘以同一个矩阵。jax.vmap 通过其 in_axes 参数提供对批处理如何应用于每个参数的细致控制。
in_axes 参数:指定映射方式默认情况下,vmap 假定你希望映射函数接收的第一个参数的第一个轴(轴 0)。如果你的函数接受多个参数,或者你想映射除了轴 0 以外的轴,你需要使用 in_axes。
in_axes 参数接受一个元组(或列表),其长度与被向量化的函数的定位参数数量相匹配。in_axes 元组的每个元素都指定了 vmap 如何处理对应的定位参数:
i:这告诉 vmap 映射对应参数的轴 i。这个轴的尺寸将成为批次维度。如果多个参数具有整数 in_axes,它们指定轴的尺寸必须匹配。None:这告诉 vmap 不要映射对应参数。相反,该参数被视为常量,并广播到其他参数的已映射维度上。我们来查看常见情况。
设想一个将向量按标量因子缩放的函数:
import jax
import jax.numpy as jnp
def scale_vector(vector, scalar):
"""将向量按标量缩放。"""
return vector * scalar
# 单个输入示例
vector = jnp.arange(3.)
scalar = 2.0
print(f"单次缩放: {scale_vector(vector, scalar)}")
# Expected Output: Single scale: [0. 2. 4.]
现在,假设你有一批向量,并且你想用相同的标量来缩放每个向量。你希望映射第一个参数(向量批次),但保持第二个参数(标量)不变。这可以通过使用 in_axes=(0, None) 实现。
# 向量批次(3 个大小为 3 的向量)
batch_of_vectors = jnp.arange(9.).reshape((3, 3))
single_scalar = 2.0
# 向量化函数:映射向量的轴 0,广播标量
vectorized_scale = jax.vmap(scale_vector, in_axes=(0, None))
# 应用向量化函数
batched_result = vectorized_scale(batch_of_vectors, single_scalar)
print("\n向量批次:")
print(batch_of_vectors)
print(f"\n单个标量: {single_scalar}")
print("\nvmap(scale_vector, in_axes=(0, None)) 后的结果:")
print(batched_result)
# Expected Output:
# Batch of vectors:
# [[0. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
#
# Single scalar: 2.0
#
# Result after vmap(scale_vector, in_axes=(0, None)):
# [[ 0. 2. 4.]
# [ 6. 8. 10.]
# [12. 14. 16.]]
这里,in_axes=(0, None) 指示 vmap:
batch_of_vectors),映射其轴 0。这个轴的尺寸(3)决定了批次大小。single_scalar),不映射(None)。将其视为固定值,并适当地广播它以匹配由映射第一个参数引入的批次维度。一个非常常见的用例是在两个数据批次之间逐元素应用操作。例如,将两个批次中对应的向量相加。
def add_vectors(vec1, vec2):
"""将两个向量相加。"""
return vec1 + vec2
# 两个向量批次(每个批次有 3 个大小为 3 的向量)
batch1 = jnp.arange(9.).reshape((3, 3))
batch2 = jnp.ones((3, 3)) * 10
print("批次 1:")
print(batch1)
print("\n批次 2:")
print(batch2)
# 向量化:映射两个参数的轴 0
vectorized_add = jax.vmap(add_vectors, in_axes=(0, 0))
# 应用向量化函数
batched_sum = vectorized_add(batch1, batch2)
print("\nvmap(add_vectors, in_axes=(0, 0)) 后的结果:")
print(batched_sum)
# Expected Output:
# Batch 1:
# [[0. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
#
# Batch 2:
# [[10. 10. 10.]
# [10. 10. 10.]
# [10. 10. 10.]]
#
# Result after vmap(add_vectors, in_axes=(0, 0)):
# [[10. 11. 12.]
# [13. 14. 15.]
# [16. 17. 18.]]
使用 in_axes=(0, 0),vmap 会从 batch1 沿轴 0 获取第 k 个切片,从 batch2 沿轴 0 获取第 k 个切片,并将它们作为 vec1 和 vec2 传递给原始的 add_vectors 函数,对每个 k 直到批次大小。批次大小由两个输入的轴 0 尺寸决定,它们必须相同(本例中为 3)。
in_axes我们可以可视化 in_axes 如何指导映射。考虑将 vmap(f, in_axes=(0, None, 0)) 应用于输入为 xs、y_val 和 zs 的 f(x, y, z)。
该图显示
vmap从xs(轴 0)中获取切片x₀, x₁, ...,并从zs(轴 0)中获取切片z₀, z₁, ...,同时将相同的y_val传递给函数f的每次并行调用。结果随后堆叠起来形成输出批次。
虽然不那么常见,但你可以通过在 in_axes 中提供不同的整数来沿不同轴映射不同的参数。例如,in_axes=(0, 1) 会映射第一个参数的轴 0 和第二个参数的轴 1。这需要仔细考虑函数的逻辑和输入数组的形状。
考虑一个将滤镜(一维向量)应用于矩阵的每行的函数:
def apply_filter(row_vector, filter_vector):
"""应用滤镜(逐元素乘法)。假设形状匹配。"""
return row_vector * filter_vector
# 一个矩阵(例如,3 行,4 列)
matrix = jnp.arange(12.).reshape((3, 4))
# 一个单滤镜向量(大小 4)
filter_v = jnp.array([1., 10., 100., 1000.])
print("矩阵:")
print(matrix)
print(f"\n滤镜: {filter_v}")
# 向量化:映射矩阵的轴 0(行),广播滤镜
# 这类似于场景 1
vectorized_apply_rows = jax.vmap(apply_filter, in_axes=(0, None))
result_rows = vectorized_apply_rows(matrix, filter_v)
print("\n将滤镜应用于每行 (in_axes=(0, None)):")
print(result_rows)
# Expected Output:
# Applying filter to each row (in_axes=(0, None)):
# [[ 0. 10. 200. 3000.]
# [ 4. 50. 600. 7000.]
# [ 8. 90. 1000. 11000.]]
# 现在,假设我们想将滤镜应用于每*列*
# 该函数期望一个向量,所以我们需要考虑转置
# 让我们定义列的滤镜(大小 3)
column_filters = jnp.array([1., 10., 100.])
# 我们希望映射矩阵的轴 1(列)和滤镜的轴 0
# 为了使 apply_filter 工作(逐元素乘积),输入需要兼容的形状。
# `vmap` 根据 `in_axes` 处理切片。
# `matrix` 形状 (3, 4),映射轴 1 -> 切片形状为 (3,) - 即列
# `column_filters` 形状 (3,),映射轴 0 -> 切片形状为 () - 即滤镜中的标量
# 这对于列和滤镜的逐元素乘法来说不太正确。
# 我们稍作修改:假设我们有一批*列操作*,
# 每个操作对整个列使用不同的标量乘数。
def scale_column(column_vector, scalar):
return column_vector * scalar
# 映射矩阵的轴 1(列),映射标量的轴 0
# 注意:列数(4)必须与标量数(轴 0 的大小)匹配
scalars = jnp.array([1., 10., 100., 1000.])
if matrix.shape[1] == scalars.shape[0]:
vectorized_apply_cols = jax.vmap(scale_column, in_axes=(1, 0))
result_cols = vectorized_apply_cols(matrix, scalars)
print("\n用不同的标量缩放每列 (in_axes=(1, 0)):")
print(result_cols)
# Expected Output:
# Scaling each column by a different scalar (in_axes=(1, 0)):
# [[ 0. 10. 200. 3000.]
# [ 4. 50. 600. 7000.]
# [ 8. 90. 1000. 11000.]]
# 注意:此处结果看起来与行示例相似,因为选择了相同的值,
# 但机制不同:每*列*都是独立缩放的。
else:
print("\n由于形状不匹配,跳过列示例。")
在这个最后的示例中,in_axes=(1, 0) 告诉 vmap:
matrix 参数,沿轴 1(列)迭代切片。每个切片的形状为 (3,)。scalars 参数,沿轴 0 迭代切片。每个切片都是一个标量。scale_column。掌握 in_axes 对于有效运用 vmap 非常重要,它让你能够以最少的代码修改对现有函数进行数据批次的向量化,即使在处理需要不同批处理或广播方式的多个输入时也是如此。
这部分内容有帮助吗?
vmap 的最新、准确信息,包括 in_axes 的用法和示例。vmap 如何融入其数值计算的可组合函数变换系统。© 2026 ApX Machine Learning用心打造