为了有效扩展 JAX 或了解其内部运作,尤其是在与外部代码交互或优化性能时,我们需要查看熟悉的 NumPy 类似 API 的表面之下。JAX 功能的核心是原语的理念。可以将原语视为 JAX 真正识别的基本、原子运算。当你使用 jax.numpy.add 或 jax.numpy.dot 等函数编写代码时,JAX 最终会(或追踪)将它们分解成一系列这些基本原语。例如,包括 add、mul、dot_general、reduce_sum、transpose、conv_general_dilated 等运算,以及 scan_p 或 cond_p 等控制流操作符。原语的意义原语因以下几个原因而有意义:转换目标: jit、grad 和 vmap 等 JAX 转换不直接作用于 Python 代码。它们在原语层面上工作。jit:将 Python 函数追踪为一系列原语,这些原语以称为 jaxpr 的中间格式表示。然后此 jaxpr 由 XLA 编译。grad:依赖于为追踪过程中遇到的每个原语定义了特定的微分规则(向量-雅可比积或 VJP 规则,以及雅可比-向量积或 JVP 规则)。vmap:使用为每个原语定义的批处理规则,以确定如何将运算映射到额外的批处理维度上。编译接口: 原语连接了高层 JAX API 和后端编译器,主要是 XLA(加速线性代数)。JAX 将 jaxpr 表示(一个原语图)转换为 XLA 的高级操作 (HLO) 指令。然后 XLA 进行复杂的优化,并将这些 HLO 指令编译成针对目标硬件(GPU、TPU 或 CPU)的高效机器码。可扩展性: 尽管 JAX 提供了一整套内置原语,但该系统被设计为可扩展的。熟悉原语是向 JAX 添加新的自定义运算的根本。如果你需要集成用 C++ 或 CUDA 编写的专用算法,或者定义一个具有独特微分规则的运算,通常可以通过定义一个新的原语并为其提供必要的规则来实现。转换过程从 Python 函数到执行的过程通常遵循以下步骤,其中原语扮演着重要角色:digraph G { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", color="#495057", fillcolor="#e9ecef", style=filled]; edge [fontname="sans-serif", color="#495057"]; py_func [label="Python 函数\n(例如,使用 jax.numpy)"]; jaxpr [label="jaxpr\n(原语序列)", color="#1c7ed6", fillcolor="#a5d8ff"]; xla_hlo [label="XLA HLO\n(硬件无关中间表示)", color="#0ca678", fillcolor="#96f2d7"]; machine_code [label="优化机器码\n(GPU/TPU/CPU)", color="#ae3ec9", fillcolor="#eebefa"]; py_func -> jaxpr [label=" JAX 追踪 \n(\`jit\`, \`grad\`, 等)"]; jaxpr -> xla_hlo [label=" JAX 降级原语 "]; xla_hlo -> machine_code [label=" XLA 编译 "]; }这是使用 JAX 运算的 Python 函数如何转换为可执行代码的视图。原语构成了 jaxpr 表示的主要部分,它是 XLA 编译器的输入。原语与函数区分 jax.numpy.add 这样的 JAX 函数和 add_p 这样的原语会有助于理解。函数 jax.numpy.add 是一个 Python 包装器,它提供了一个熟悉的 NumPy 风格接口。当 JAX 使用此函数追踪代码时,它通常会将相应的原语 add_p 注册到 jaxpr 中。原语是内部表示,承载着转换和编译所需的必要信息。实质上,原语是不可简化的构成单元,JAX 强大的转换和编译系统就是建立在其之上。尽管在日常使用中你可能不会直接与它们交互,但了解它们的作用对于深入优化、调试和扩展 JAX 的能力是不可或缺的,尤其是在我们将在接下来的章节中研究自定义运算时。