趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数JAX 的首要设计理念是以函数变换为中心。不同于可能引入全新数据结构或操作的库,JAX 主要提供用于转换您现有 Python 和 NumPy 代码的工具。可以将 JAX 视为不仅仅是 NumPy 的替代品,而是作为一个作用于您函数之上的元库。
这种方法植根于函数式编程原则。JAX 提倡编写纯函数:其输出仅取决于其明确输入,并且不产生任何副作用(例如修改全局变量或打印到控制台)的函数。尽管这最初可能看起来有局限性,特别是如果您习惯于方法会修改内部状态的面向对象风格,但它是 JAX 实现其性能和功能的核心所在。
为何如此强调纯函数和变换?因为纯函数更容易被自动分析、优化、编译、求导和并行化。当一个函数仅根据其输入可预测地运行时,JAX 可以使用符号(“追踪器”)输入对其执行进行一次追踪,理解操作序列,然后对这个追踪到的计算图应用强大的变换。
您在本课程中会遇到的主要函数变换是:
jax.jit:即时编译。接受一个 Python 函数,并使用 XLA(加速线性代数)对其进行编译,以大幅提升速度,尤其是在 GPU 和 TPU 上。jax.grad:自动微分。接受一个数值函数,并返回一个计算其梯度的函数。对于机器学习中常见的优化任务来说非常重要。jax.vmap:自动向量化(或批处理)。接受一个设计用于处理单个数据点的函数,并将其转换为一个能够高效处理沿数组轴的批次数据的函数,通常无需手动循环。jax.pmap:跨设备并行化。接受一个函数,并将其编译为可在多个设备(例如,多个 GPU 或 TPU 核心)上并行运行,实现单程序多数据(SPMD)并行。这种设计的一个强大特点是可组合性。这些变换并非相互排斥;您可以将它们组合使用。例如,您可以将一个梯度函数进行即时编译,或者对一个本身包含梯度计算的函数进行向量化。
考虑一个简单的 Python 函数:
import jax
import jax.numpy as jnp
def predict(params, inputs):
# 一个简单的线性模型
return jnp.dot(inputs, params['w']) + params['b']
def loss_fn(params, inputs, targets):
predictions = predict(params, inputs)
error = predictions - targets
return jnp.mean(error**2) # 均方误差
使用 JAX 变换,您可以轻松地:
params 的梯度:
grad_fn = jax.grad(loss_fn) # 对 loss_fn 相对于第一个参数 (params) 进行微分
jit_grad_fn = jax.jit(grad_fn)
predict 函数进行向量化以处理一批 inputs:
# 假设 predict 对单个输入起作用,vmap 使其能处理一批输入
# 映射 'inputs' 参数 (轴 0),params 共享 (None)
batched_predict = jax.vmap(predict, in_axes=(None, 0))
如果需要,您甚至可以将 jit、grad 和 vmap 组合使用。这种可组合性使您能够从更简单的纯 Python 函数构建复杂的、高性能的计算流程。
函数式纯度的要求与 jit 等变换的工作方式直接关联。JAX 使用代表数组的形状和类型的抽象值来追踪函数的执行,而非其实际值(除非标记为静态)。副作用在执行过程中与外部交互,通常无法被此追踪过程捕获,导致编译/变换后的函数运行时出现意料之外的行为或错误。我们之前讨论过的 JAX 数组的不可变性,也是这种函数式方法的成果和促进因素。您不会就地修改数组;操作总是返回新数组,从而保持可靠变换所需的纯度。
这种设计理念,将熟悉的类似 NumPy 的 API 与基于函数式编程原则的可组合函数变换相结合,使 JAX 能够高效地支持现代硬件加速器并提供强大的自动微分功能,使其成为数值计算和机器学习研究的重要工具。后续章节将详细介绍每种主要变换(jit、grad、vmap、pmap)。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造