JAX 的 jit 装饰器是一个强大工具,通过使用 XLA 编译 Python 函数来加速计算。然而,这个编译过程并非没有开销。JAX 首先会追踪 Python 函数,以生成一个中间表示,即 jaxpr,它记录了基本操作的序列。然后,XLA 会将这个 jaxpr 编译成针对目标硬件(CPU、GPU、TPU)以及在追踪时遇到的输入参数的形状和类型优化的代码。首次调用带有 @jit 装饰的函数时,如果参数的形状和类型是特定组合,JAX 会执行追踪和编译,并将生成的优化可执行文件存储在缓存中。后续使用相同参数形状和类型的调用可以重用缓存的可执行文件,从而明显加快速度,因为它们避免了追踪和编译的开销。然而,如果调用函数时,参数的特性与之前遇到的不同,并且这种不同导致 jaxpr 发生变化,JAX 就必须重新追踪并重新编译。这种重复编译会增加大量开销,可能会抵消 jit 的优势,特别是当它频繁发生时(例如,在循环内部)。了解并减少这些重复编译事件对于获得最高性能非常重要。为什么会发生重复编译JAX 的追踪机制通过用特殊的追踪器对象替换函数的参数来工作。这些追踪器记录了所执行的 JAX 基本操作的序列。主要原因是追踪依赖于输入参数的以下方面:参数形状: 如果您第一次使用形状为 (10, 5) 的数组调用函数,之后又使用形状为 (20, 5) 的数组,jaxpr 中的操作序列或它们的维度可能会变化。这会触发重新追踪和重复编译。参数数据类型(dtypes): 使用 float32 参数后又使用 float64 参数调用,也可能触发重复编译,因为可能需要不同的基本操作实现。参数类型: 在不同调用中对相同参数位置混合使用整数和浮点数等 Python 类型可能会导致重复编译。JAX 需要知道参数是 JAX 数组(被追踪)还是标准 Python 类型(被视为静态)。依赖值的 Python 控制流: 如果您的函数使用标准的 Python if 语句或 for 循环,其条件或迭代依赖于被追踪参数(JAX 数组)的值,那么在追踪过程中执行的路径可能会改变。不同的路径意味着不同的 jaxpr,从而导致重复编译。函数参数的结构: 如果您将函数作为参数传递(在高阶函数中常见),那么在不同调用之间改变函数本身会需要重复编译。闭包和全局变量: 如果 JIT 编译的函数引用了 Python 变量或访问了全局变量,并且这些变量的值影响了计算图的结构(例如,决定中间数组的大小),那么在不同调用之间这些值的变化会触发重复编译。请看这个简单例子:import jax import jax.numpy as jnp import time @jax.jit def example_function(x, size_param): # Python 控制流依赖于 size_param 的值 if size_param > 5: return jnp.sum(x * 2.0) else: return jnp.sum(x + 1.0) key = jax.random.PRNGKey(0) data = jax.random.normal(key, (10,)) # 第一次调用:size_param = 3。追踪并编译版本 1。 print("First call...") start_time = time.time() result1 = example_function(data, 3) result1.block_until_ready() # 确保执行完成 print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s") # 第二次调用(已缓存):size_param = 3。使用缓存的版本 1。速度快。 print("\nSecond call (cached)...") start_time = time.time() result2 = example_function(data, 3) result2.block_until_ready() print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s") # 第三次调用(重复编译):size_param = 7。触发重新追踪和重新编译版本 2。速度慢。 print("\nThird call (recompilation)...") start_time = time.time() result3 = example_function(data, 7) result3.block_until_ready() print(f"Result: {result3}, Time: {time.time() - start_time:.4f}s") # 第四次调用(已缓存):size_param = 7。使用缓存的版本 2。速度快。 print("\nFourth call (cached)...") start_time = time.time() result4 = example_function(data, 7) result4.block_until_ready() print(f"Result: {result4}, Time: {time.time() - start_time:.4f}s")您会发现第一次和第三次调用明显更慢,这是由于编译开销。第二次和第四次调用重用了缓存的可执行文件,速度快得多。第三次调用发生重复编译,是因为 size_param 的值改变了在追踪过程中采用的 Python 控制流路径。减少重复编译的策略虽然某些重复编译是预期的(例如,当形状确实改变时),但频繁的、非预期的重复编译会影响性能。以下是减轻这种情况的重要策略:1. 使用静态参数(static_argnums / static_argnames)当函数的参数值影响计算图结构(例如形状、决定层的超参数,或 Python 控制流中使用的值),但参数本身不打算被追踪(例如,它是一个 Python int 或 bool)时,您可以将其标记为“静态”。JAX 会将静态参数视为编译时常量。它会为这些静态参数的每种唯一值组合追踪并编译一个专用版本的函数。这避免了仅静态参数改变时发生重复编译,因为 JAX 可以查找预编译的专用版本。您可以使用 jax.jit 的 static_argnums(按索引)或 static_argnames(按名称,通常更清晰)参数来指定静态参数:import jax import jax.numpy as jnp import time # 将 'size_param'(索引 1)标记为静态 @jax.jit(static_argnums=(1,)) def example_function_static(x, size_param): print(f"Compiling for size_param = {size_param}") # 查看何时发生编译 if size_param > 5: return jnp.sum(x * 2.0) else: return jnp.sum(x + 1.0) key = jax.random.PRNGKey(0) data = jax.random.normal(key, (10,)) # 第一次调用:size_param = 3。编译专用版本 1。 print("First call...") start_time = time.time() result1 = example_function_static(data, 3) result1.block_until_ready() print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s") # 第二次调用(已缓存):size_param = 3。使用缓存的专用版本 1。速度快。 print("\nSecond call (cached)...") start_time = time.time() result2 = example_function_static(data, 3) result2.block_until_ready() print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s") # 第三次调用(新静态值):size_param = 7。编译专用版本 2。 print("\nThird call (new static value)...") start_time = time.time() result3 = example_function_static(data, 7) result3.block_until_ready() print(f"Result: {result3}, Time: {time.time() - start_time:.4f}s") # 第四次调用(已缓存):size_param = 7。使用缓存的专用版本 2。速度快。 print("\nFourth call (cached)...") start_time = time.time() result4 = example_function_static(data, 7) result4.block_until_ready() print(f"Result: {result4}, Time: {time.time() - start_time:.4f}s")现在,对于 size_param=3 只会编译一次,对于 size_param=7 也只编译一次。后续使用这些值的调用会重用相应的缓存专用版本。注意: 明智地使用静态参数。如果一个静态参数可以取很多不同的值,您最终可能会编译出大量专用版本,从而增加编译时间和编译缓存的内存使用。它最适合用于取值范围有限且控制图结构的参数。2. 确保输入类型和形状一致尽可能地,尝试使用形状和数据类型一致的数组来调用您的 JIT 编译函数。填充: 如果处理可变长度序列,请考虑将它们填充到固定的最大长度,并使用掩码技术(参见第一章)来忽略填充的元素。类型转换: 如果输入可能带有不同类型,请在调用 JIT 函数之前显式地将其转换为所需的数据类型。3. 使用 JAX 控制流基本操作将依赖于被追踪的值的 Python if、for 和 while 循环替换为它们的 JAX 对应项:jax.lax.cond、jax.lax.scan 和 jax.lax.while_loop。这些基本操作已集成到 JAX 追踪系统。它们将分支或循环逻辑嵌入到已编译的 XLA 图中,而不是根据 Python 执行路径创建不同的图。这避免了当控制流的值改变时发生重复编译。import jax import jax.numpy as jnp import time # 使用 lax.cond 代替 Python 的 if @jax.jit def example_function_lax(x, size_param_val): # size_param_val 必须是 0 维数组或可被 cond 追踪的 Python 标量 # 这里我们假设它来自数据或已适当传递 pred = size_param_val > 5 # 为真分支和假分支定义函数 def true_fun(operand): return jnp.sum(operand * 2.0) def false_fun(operand): return jnp.sum(operand + 1.0) # lax.cond 根据 pred 选择执行哪个函数 return jax.lax.cond(pred, true_fun, false_fun, x) key = jax.random.PRNGKey(0) data = jax.random.normal(key, (10,)) # 例子:传递一个可被 cond 追踪的 JAX 标量 # 通常,size_param_val 会从其他 JAX 数组计算得出 size_val_3 = jnp.array(3) size_val_7 = jnp.array(7) # 第一次调用:只编译一次。 print("First call (lax.cond)...") start_time = time.time() result1 = example_function_lax(data, size_val_3) result1.block_until_ready() print(f"Result: {result1}, Time: {time.time() - start_time:.4f}s") # 第二次调用:使用缓存版本。速度快。没有重复编译。 print("\nSecond call (lax.cond, cached)...") start_time = time.time() result2 = example_function_lax(data, size_val_7) result2.block_until_ready() print(f"Result: {result2}, Time: {time.time() - start_time:.4f}s")使用 lax.cond,只发生一次编译,并在 XLA 图中处理两个分支。4. 仔细管理闭包和函数参数避免引用其值频繁变化并影响图结构的变量。而是将这些值作为常规(可能静态的)参数传递。如果将函数作为参数传递(例如,激活函数),请确保每次都传递相同的函数对象。如果需要绑定参数,使用 functools.partial 有时可以帮助创建稳定的函数对象。在顶层定义函数通常可以避免动态地在循环内部定义它们所带来的问题。通过积极找出代码中重复编译的原因(通过性能分析,下一节会讲到)并应用这些策略,特别是 static_argnums 和 JAX 控制流基本操作,您可以大幅减少编译开销,并确保您的 @jit 装饰函数在首次编译后持续快速运行。