趋近智
jax.jit 通过跟踪您的 Python 函数来发挥其作用。它使用表示输入的占位符对象(称为跟踪器)执行函数一次。这些跟踪器会记录在其上执行的所有操作,构建一个中间表示(Jaxpr),然后由 XLA 编译成优化过的机器码。
这种跟踪机制导致了一个重要的区别:函数中涉及的一些值是跟踪的,而另一些则必须被视为静态的。理解这一区别对于正确高效地使用 jit 具有重要意义。
大多数情况下,JIT 编译函数的输入将是 JAX 数组(或包含 JAX 数组的结构)。当 jit 跟踪函数时,这些数组输入会被跟踪器对象替换。
jnp.dot、jnp.sin 等)时,它不会立即执行计算。相反,它将该操作记录为计算图(Jaxpr)的一部分。实际的数值计算发生在编译之后,当您使用具体数据调用编译后的函数时。可以把它想象成绘制一个计算的蓝图。蓝图规定了材料的尺寸和类型(形状和 dtype)以及要遵循的步骤(操作),但在您开始构建(执行编译后的代码)之前,不需要实际的物理材料(具体值)。
相比之下,静态值是在编译时(跟踪期间)已知且固定的值。它们在编译后的代码中被视为常量。
if 语句的条件评估为静态 True,则只有该 if 块内的代码才会被包含在编译后的 Jaxpr 中。当标准 Python 控制流(如 if 语句或 for 循环)依赖于 JAX 默认作为跟踪值处理的输入值时,核心问题就会出现。
请看这个简单的函数:
import jax
import jax.numpy as jnp
def conditional_double(x, threshold):
# Python if 语句依赖于 x 的值
if x > threshold:
return x * 2
else:
return x
我们来尝试对其进行 JIT 编译:
jitted_conditional_double = jax.jit(conditional_double)
# 这很可能会引发错误!
try:
result = jitted_conditional_double(jnp.array(5.0), threshold=0.0)
print(result)
except Exception as e:
print(f"错误: {e}")
您会遇到一个类似这样的错误:ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected...。
为什么会出现错误? 在跟踪期间,x 被跟踪器替换。Python 的 if 语句尝试评估 tracer > threshold。但跟踪器没有具体的数值;它只知道其形状和 dtype。它无法在编译时为 Python 的 if 语句生成所需的单个布尔值 True 或 False。Python 需要一个具体的布尔值来立即决定跟踪哪个分支,但 JAX 只有一个占位符。
static_argnums 和 static_argnames为了处理控制流或其他函数逻辑在跟踪期间必须依赖于参数具体值的情况,JAX 提供了将参数标记 (token)为静态的方法。
您可以使用 static_argnums 或 static_argnames 参数告诉 jit 将特定参数视为静态:
static_argnums:提供一个整数元组,指定应为静态的参数的位置索引。static_argnames:提供一个字符串元组,指定应为静态的参数(位置或关键字)的名称。让我们修正之前的例子。假设我们知道 threshold 在多次调用中很可能保持不变,或者我们需要 if 语句像标准 Python 逻辑一样工作。我们可以将 threshold 标记为静态:
# 使用 static_argnums(threshold 是第 1 个参数,索引为 1)
jitted_conditional_double_nums = jax.jit(conditional_double, static_argnums=(1,))
# 使用 static_argnames
jitted_conditional_double_names = jax.jit(conditional_double, static_argnames=('threshold',))
# 现在这些可以工作了:
result1 = jitted_conditional_double_nums(jnp.array(5.0), threshold=0.0)
print(f"结果 (static_argnums): {result1}") # 输出:结果 (static_argnums): 10.0
result2 = jitted_conditional_double_names(jnp.array(-2.0), threshold=0.0)
print(f"结果 (static_argnames): {result2}") # 输出:结果 (static_argnames): -2.0
现在,当 jit 跟踪函数时,它知道 threshold 是静态的。在跟踪期间,它会替换为 threshold 提供的实际值(例如 0.0)。Python 的 if 语句随后可以评估 tracer > 0.0。虽然这仍然涉及跟踪器,但 JAX 有时可以通过使用专门的原始操作(如 lax.cond)将条件逻辑嵌入 (embedding)到编译后的代码中来处理涉及常量和跟踪器的比较。然而,依赖于跟踪值的更复杂的 Python 逻辑仍然会失败。将控制 Python 逻辑的值设为静态,确保 Python 解释器可以在跟踪期间执行控制流。
将参数 (parameter)标记 (token)为静态允许在 JIT 编译函数中使用更多标准 Python 结构,但这带来了一个重要的性能考量:重新编译。
jit 的优势。# 这里的 threshold 是静态的
jitted_func = jax.jit(conditional_double, static_argnames=('threshold',))
print("第一次调用 (threshold=0.0):")
_ = jitted_func(jnp.array(5.0), threshold=0.0) # 为 threshold=0.0 编译
print("第二次调用 (threshold=0.0):")
_ = jitted_func(jnp.array(1.0), threshold=0.0) # 重用已编译代码
print("第三次调用 (threshold=10.0):")
_ = jitted_func(jnp.array(5.0), threshold=10.0) # *** 为 threshold=10.0 重新编译 ***
print("第四次调用 (threshold=10.0):")
_ = jitted_func(jnp.array(1.0), threshold=10.0) # 重用为 threshold=10.0 编译的代码
JAX 根据 Python 函数对象的同一性以及静态参数值(连同输入形状/dtypes)缓存编译后的函数。
您通常在以下情况需要静态参数:
if、for、while 循环,其条件或迭代直接依赖于参数的具体值(不仅仅是其形状)。一般指导原则: 对于经常变化的数值数据(JAX 数组),优先使用跟踪参数(默认)。对于控制计算结构或在跟踪期间 Python 运行时逻辑所需、且在调用之间不常变化的值,请谨慎使用静态参数。
如果您发现自己需要基于跟踪值的动态控制流,可以了解 JAX 的结构化控制流原始操作,如 lax.cond、lax.scan 和 lax.switch,它们被设计为可跟踪和可编译的。我们将在稍后提及这些,它们提供了一种表达与 JIT 兼容的动态计算图的方法。
理解静态值和跟踪值之间的区别对于调试 jit 问题以及通过最小化重新编译来优化性能很重要。通过仔细考虑哪些参数需要是静态的,您可以有效地加速您的 JAX 代码。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•