趋近智
虽然 jax.jit 能为加速器编译您的 Python 函数,但有时需要仔细研究其内部运作,以理解 JAX 究竟在编译什么。调试性能问题或异常行为时,这会特别有用。JAX 使用一种名为 jaxpr(JAX 程序表示)的中间表示来表示从您的 Python 代码派生出的计算图,然后将其传递给 XLA 编译器。理解 jaxpr 有助于识别 JAX 如何追踪函数,并能定位效率低下或重新编译的起因。
您可以将 jaxpr 视为您计算的简化、函数式和明确的表示。当 JAX 追踪您的 Python 函数时(这发生在您首次使用特定输入类型和形状调用 jit 编译的函数时),它不会直接执行 Python 代码。相反,它会用特殊的追踪器对象替换您的输入,并记录在这些追踪器上执行的每个 JAX 原语操作。此追踪过程的结果就是 jaxpr。
jaxpr 的特点包括:
add、sin、dot_general、reduce_sum)应用于输入变量或常量,以生成输出变量。Python 控制流,例如 if 语句或 for 循环(除非使用 JAX 控制流原语,如 lax.cond 或 lax.scan),在追踪过程中会被展开。在典型的模型开发过程中,您通常不会直接与 jaxpr 交互,但 JAX 提供了在需要时检查它们的工具。主要用于此的函数是 jax.make_jaxpr。它接收一个函数和示例输入(如 jit),并返回一个表示 jaxpr 的对象,以及其他信息。
我们来看一个简单例子:
import jax
import jax.numpy as jnp
def example_function(x, y):
a = jnp.sin(x)
b = jnp.cos(y)
return a + b * 2.0
# Define example inputs
key = jax.random.PRNGKey(0)
x_example = jax.random.normal(key, (10,))
y_example = jax.random.normal(key, (10,))
# Generate the jaxpr
jaxpr_obj = jax.make_jaxpr(example_function)(x_example, y_example)
print(jaxpr_obj.jaxpr)
运行此代码将打印 jaxpr 对象。其结构可能类似于这样(细节可能因 JAX 版本而异):
{ lambda ; a:f32[10] b:f32[10]. let
c:f32[10] = sin a
d:f32[10] = cos b
e:f32[10] = mul d 2.0
f:f32[10] = add c e
in (f,) }
我们来分解一下这个打印出的表示:
{ lambda ; ... }:将 jaxpr 定义为一个 lambda 函数。a:f32[10] b:f32[10]:这些是输入变量 (invars) 及其类型 (float32) 和形状 ([10])。它们对应于 example_function 的 x 和 y 参数。let ... in ...:这表示 jaxpr 的主体部分。c:f32[10] = sin a:这是一个等式 (eqn)。它将 sin 原语应用于输入变量 a,并将结果绑定到一个新的中间变量 c,其类型也是 f32[10]。d:f32[10] = cos b:另一个等式,将 cos 原语应用于 b,得到 d。e:f32[10] = mul d 2.0:将乘法原语 (mul) 应用于变量 d 和常量 2.0。从函数环境 (constvars) 捕获的常量也在此处出现。f:f32[10] = add c e:将加法原语 (add) 应用于中间变量 c 和 e,生成 f。in (f,):指定 jaxpr 的输出变量 (outvars),在此例中,仅为 f。请注意,Python 代码的结构是如何被转换为作用于类型化变量的线性原语操作序列的。这种明确、简化的形式比原始 Python 代码更容易让编译器进行分析和优化。
理解 jaxpr 在优化 JAX 代码时有多种帮助:
jax.numpy 使用上的细微差别可能会导致不同的原语。jax.grad 或 jax.vmap 等转换后检查 jaxpr,以查看它们如何修改计算图。这对于调试求导或向量化的行为很有用。jax.jit 会重新编译您的函数。这通常发生在操作序列依赖于参数的值,而不仅仅是它们的形状和类型时(这是基于张量值的 Python 级控制流的常见问题)。比较不同调用生成的 jaxpr 可以突出显示重新编译的原因。如果 eqns 列表差异明显,则表明存在 JAX 必须重新追踪的动态行为。考虑一个使用 jax.lax.cond 的函数,这是 JAX 用于 jit 编译代码中条件执行的原语:
import jax
import jax.numpy as jnp
from jax import lax
def conditional_function(use_sin, x):
# 注意:lax.cond 的 `pred` 必须是标量布尔值
pred = use_sin > 0.5
# 定义真分支和假分支的函数
def true_fun(operand):
return jnp.sin(operand)
def false_fun(operand):
return jnp.cos(operand)
# 使用 lax.cond
return lax.cond(pred, true_fun, false_fun, x)
# 示例输入
x_example = jnp.ones((3,))
pred_true_example = jnp.array(0.7) # 标量值 > 0.5
pred_false_example = jnp.array(0.3) # 标量值 <= 0.5
# 当 pred 为 True 时的 Jaxpr
jaxpr_true = jax.make_jaxpr(conditional_function)(pred_true_example, x_example)
print("Jaxpr (可能执行真分支):")
print(jaxpr_true.jaxpr)
# 当 pred 为 False 时的 Jaxpr
jaxpr_false = jax.make_jaxpr(conditional_function)(pred_false_example, x_example)
print("\nJaxpr (可能执行假分支):")
print(jaxpr_false.jaxpr)
您会注意到生成的 jaxpr 包含一个 cond 原语。重要的是,即使我们提供了具体的布尔值(0.7 对应 True,0.3 对应 False),jaxpr 本身也不仅仅包含一个分支的操作。相反,它包含了 cond 原语,该原语封装了两个分支的逻辑。无论追踪过程中谓词的具体值如何,只要两个分支的输入和输出的类型和形状一致,jaxpr 看起来都会相似。这对于 jit 编译非常重要,因为编译后的代码需要在运行时处理任一分支。
# Jaxpr 的简化表示(结构可能不同)
{ lambda ; a:f32[] b:f32[3]. let
c:bool[] = gt a 0.5
# 真分支 jaxpr 的定义(例如,{ lambda ; x:f32[3]. let y = sin x in (y,) })
# 假分支 jaxpr 的定义(例如,{ lambda ; x:f32[3]. let y = cos x in (y,) })
d:f32[3] = cond c true_branch false_branch b # 传递给 cond 的操作数
in (d,) }
检查 jaxpr 会显示 cond 原语和传递给 XLA 的结构,这证实了 JAX 使用其特定原语正确追踪了条件逻辑。
尽管 jaxpr 提供了很多信息,但它仍是一个中间步骤。它不显示 XLA 执行的最终优化,例如操作符融合(将多个原语合并为一个更高效的内核)或为目标加速器生成的精确低级代码。然而,它提供了对编译器输入的重要视图,使其成为调试性能和了解 JAX 内部机制的必备工具,尤其是在简单应用 @jit 不足以解决问题时。
通过熟悉 jaxpr,您能更清楚地理解 JAX 如何将您的 Python 代码转换为适合高性能编译的形式,使您能够编写更高效、更可预测的 JAX 程序。
这部分内容有帮助吗?
lax.cond,展示它们在jaxpr中的表示。© 2026 ApX Machine Learning用心打造