趋近智
JAX 的 jit 装饰器是一个强大工具,通过使用 XLA 编译 Python 函数来加速计算。然而,这个编译过程并非没有开销。JAX 首先会追踪 Python 函数,以生成一个中间表示,即 jaxpr,它记录了基本操作的序列。然后,XLA 会将这个 jaxpr 编译成针对目标硬件(CPU、GPU、TPU)以及在追踪时遇到的输入参数的形状和类型优化的代码。
首次调用带有 @jit 装饰的函数时,如果参数的形状和类型是特定组合,JAX 会执行追踪和编译,并将生成的优化可执行文件存储在缓存中。后续使用相同参数形状和类型的调用可以重用缓存的可执行文件,从而明显加快速度,因为它们避免了追踪和编译的开销。
然而,如果调用函数时,参数的特性与之前遇到的不同,并且这种不同导致 jaxpr 发生变化,JAX 就必须重新追踪并重新编译。这种重复编译会增加大量开销,可能会抵消 jit 的优势,特别是当它频繁发生时(例如,在循环内部)。了解并减少这些重复编译事件对于获得最高性能非常重要。
JAX 的追踪机制通过用特殊的追踪器对象替换函数的参数来工作。这些追踪器记录了所执行的 JAX 基本操作的序列。主要原因是追踪依赖于输入参数的以下方面:
(10, 5) 的数组调用函数,之后又使用形状为 (20, 5) 的数组,jaxpr 中的操作序列或它们的维度可能会变化。这会触发重新追踪和重复编译。float32 参数后又使用 float64 参数调用,也可能触发重复编译,因为可能需要不同的基本操作实现。if 语句或 for 循环,其条件或迭代依赖于被追踪参数(JAX 数组)的值,那么在追踪过程中执行的路径可能会改变。不同的路径意味着不同的 jaxpr,从而导致重复编译。请看这个简单例子:
import jax
import jax.numpy as jnp
import time
@jax.jit
def example_function(x, size_param):
# Python 控制流依赖于 size_param 的值
if size_param > 5:
return jnp.sum(x * 2.0)
else:
return jnp.sum(x + 1.0)
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (10,))
# 第一次调用:size_param = 3。追踪并编译版本 1。
print("First call...")
start_time = time.time()
result1 = example_function(data, 3)
result1.block_until_ready() # 确保执行完成
print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s")
# 第二次调用(已缓存):size_param = 3。使用缓存的版本 1。速度快。
print("\nSecond call (cached)...")
start_time = time.time()
result2 = example_function(data, 3)
result2.block_until_ready()
print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s")
# 第三次调用(重复编译):size_param = 7。触发重新追踪和重新编译版本 2。速度慢。
print("\nThird call (recompilation)...")
start_time = time.time()
result3 = example_function(data, 7)
result3.block_until_ready()
print(f"Result: {result3}, Time: {time.time() - start_time:.4f}s")
# 第四次调用(已缓存):size_param = 7。使用缓存的版本 2。速度快。
print("\nFourth call (cached)...")
start_time = time.time()
result4 = example_function(data, 7)
result4.block_until_ready()
print(f"Result: {result4}, Time: {time.time() - start_time:.4f}s")
您会发现第一次和第三次调用明显更慢,这是由于编译开销。第二次和第四次调用重用了缓存的可执行文件,速度快得多。第三次调用发生重复编译,是因为 size_param 的值改变了在追踪过程中采用的 Python 控制流路径。
虽然某些重复编译是预期的(例如,当形状确实改变时),但频繁的、非预期的重复编译会影响性能。以下是减轻这种情况的重要策略:
static_argnums / static_argnames)当函数的参数值影响计算图结构(例如形状、决定层的超参数,或 Python 控制流中使用的值),但参数本身不打算被追踪(例如,它是一个 Python int 或 bool)时,您可以将其标记为“静态”。
JAX 会将静态参数视为编译时常量。它会为这些静态参数的每种唯一值组合追踪并编译一个专用版本的函数。这避免了仅静态参数改变时发生重复编译,因为 JAX 可以查找预编译的专用版本。
您可以使用 jax.jit 的 static_argnums(按索引)或 static_argnames(按名称,通常更清晰)参数来指定静态参数:
import jax
import jax.numpy as jnp
import time
# 将 'size_param'(索引 1)标记为静态
@jax.jit(static_argnums=(1,))
def example_function_static(x, size_param):
print(f"Compiling for size_param = {size_param}") # 查看何时发生编译
if size_param > 5:
return jnp.sum(x * 2.0)
else:
return jnp.sum(x + 1.0)
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (10,))
# 第一次调用:size_param = 3。编译专用版本 1。
print("First call...")
start_time = time.time()
result1 = example_function_static(data, 3)
result1.block_until_ready()
print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s")
# 第二次调用(已缓存):size_param = 3。使用缓存的专用版本 1。速度快。
print("\nSecond call (cached)...")
start_time = time.time()
result2 = example_function_static(data, 3)
result2.block_until_ready()
print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s")
# 第三次调用(新静态值):size_param = 7。编译专用版本 2。
print("\nThird call (new static value)...")
start_time = time.time()
result3 = example_function_static(data, 7)
result3.block_until_ready()
print(f"Result: {result3}, Time: {time.time() - start_time:.4f}s")
# 第四次调用(已缓存):size_param = 7。使用缓存的专用版本 2。速度快。
print("\nFourth call (cached)...")
start_time = time.time()
result4 = example_function_static(data, 7)
result4.block_until_ready()
print(f"Result: {result4}, Time: {time.time() - start_time:.4f}s")
现在,对于 size_param=3 只会编译一次,对于 size_param=7 也只编译一次。后续使用这些值的调用会重用相应的缓存专用版本。
注意: 明智地使用静态参数。如果一个静态参数可以取很多不同的值,您最终可能会编译出大量专用版本,从而增加编译时间和编译缓存的内存使用。它最适合用于取值范围有限且控制图结构的参数。
尽可能地,尝试使用形状和数据类型一致的数组来调用您的 JIT 编译函数。
将依赖于被追踪的值的 Python if、for 和 while 循环替换为它们的 JAX 对应项:jax.lax.cond、jax.lax.scan 和 jax.lax.while_loop。这些基本操作已集成到 JAX 追踪系统。它们将分支或循环逻辑嵌入到已编译的 XLA 图中,而不是根据 Python 执行路径创建不同的图。这避免了当控制流的值改变时发生重复编译。
import jax
import jax.numpy as jnp
import time
# 使用 lax.cond 代替 Python 的 if
@jax.jit
def example_function_lax(x, size_param_val):
# size_param_val 必须是 0 维数组或可被 cond 追踪的 Python 标量
# 这里我们假设它来自数据或已适当传递
pred = size_param_val > 5
# 为真分支和假分支定义函数
def true_fun(operand):
return jnp.sum(operand * 2.0)
def false_fun(operand):
return jnp.sum(operand + 1.0)
# lax.cond 根据 pred 选择执行哪个函数
return jax.lax.cond(pred, true_fun, false_fun, x)
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (10,))
# 例子:传递一个可被 cond 追踪的 JAX 标量
# 通常,size_param_val 会从其他 JAX 数组计算得出
size_val_3 = jnp.array(3)
size_val_7 = jnp.array(7)
# 第一次调用:只编译一次。
print("First call (lax.cond)...")
start_time = time.time()
result1 = example_function_lax(data, size_val_3)
result1.block_until_ready()
print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s")
# 第二次调用:使用缓存版本。速度快。没有重复编译。
print("\nSecond call (lax.cond, cached)...")
start_time = time.time()
result2 = example_function_lax(data, size_val_7)
result2.block_until_ready()
print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s")
使用 lax.cond,只发生一次编译,并在 XLA 图中处理两个分支。
functools.partial 有时可以帮助创建稳定的函数对象。在顶层定义函数通常可以避免动态地在循环内部定义它们所带来的问题。通过积极找出代码中重复编译的原因(通过性能分析,下一节会讲到)并应用这些策略,特别是 static_argnums 和 JAX 控制流基本操作,您可以大幅减少编译开销,并确保您的 @jit 装饰函数在首次编译后持续快速运行。
这部分内容有帮助吗?
jax.jit, Rosalia Schneider, Vladimir Mikulik, 2024 (Read the Docs) - 官方 JAX 指南,解释了 JIT 编译、追踪、jaxpr、重新编译的原因,以及使用 static_argnums 进行优化的方法。© 2026 ApX Machine Learning用心打造