趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数如本章引言所述,在科学计算和机器学习中,对许多不同输入应用相同操作是一种常见需求。可以设想处理一批图像、运行带有不同参数的模拟,或者计算训练集中多个样本的损失贡献。
Python 提供了一种直接方法来处理对多个输入应用相同操作的需求,那就是使用标准 for 循环。例如,如果您有一个作用于单个图像的函数 predict(image),您可能像这样处理一批数据:
def predict_batch_loop(images):
results = []
for image in images:
results.append(predict(image))
return results # Or stack them into an array
虽然直接,但这种方式常遇到性能瓶颈。Python 循环的每次迭代都带有开销,并且内部操作可能无法有效利用现代硬件(如 GPU 或 TPU)的并行处理能力。Python 的解释型特性使得循环比经过优化的编译代码执行的操作慢很多。
向量化提供了一种更高效的替代方案。其核心是,向量化是一种使标量操作(或对小尺寸、固定尺寸输入的操作)适应的技术,使其能在整个数组或数据批次上同时进行。 向量化操作并非在循环中逐个处理元素,而是在多个数据点上同时执行计算,至少从高层代码的角度来看是这样。
考虑将一个常数值加到一组数字上。循环方式会迭代:
# Looping approach
numbers = [1, 2, 3, 4, 5]
results = []
for x in numbers:
results.append(x + 10)
# results: [11, 12, 13, 14, 15]
向量化方式,在 NumPy 这样的库中很常见,看起来像这样:
# Vectorized approach (NumPy example)
import numpy as np
numbers = np.array([1, 2, 3, 4, 5])
results = numbers + 10
# results: array([11, 12, 13, 14, 15])
在这里,numbers + 10 在整个 numpy 数组上应用加法,没有显式的 Python 循环。在内部,NumPy(和类似的库)将此操作委托给高度优化、预编译的代码(通常用 C 或 Fortran 编写),这种代码处理大量数据比 Python 解释器快得多。
循环中顺序处理项目与向量化方式同时处理项目的差异。
向量化的主要原因在于性能。
除了速度优势,向量化代码通常也更简洁易读。numbers + 10 的例子可以说更简单,并且比显式 for 循环更直接地表达了意图。
虽然 NumPy 等库提供了本身就是向量化的函数(如 np.add、np.sin 等),但它们通常要求您以一种与这些函数自然契合的方式组织输入数据并编写代码(例如,确保数组维度与广播兼容)。
JAX 通过 jax.vmap 更进一步。vmap 的思想是让您编写函数逻辑,就像它操作的是单个数据点一样(例如 predict(image))。然后,您可以使用 vmap 自动将该函数转换为一个新函数,该函数可以高效处理批次数据,无需手动重写批处理的核心逻辑或管理循环。它实质上是自动添加了“批次维度”,将计算映射到输入数组的切片上。
理解向量化这个原理,对于掌握 vmap 的工作方式及其在 JAX 生态系统中简化代码和提升性能的实用价值,非常重要。在接下来的章节中,我们将了解如何将 jax.vmap 付诸实践。
这部分内容有帮助吗?
vmap, JAX core contributors, 2024 (The JAX Project) - 解释 jax.vmap 的概念和用法,实现自动批处理和高效执行的官方指南。© 2026 ApX Machine Learning用心打造