趋近智
jax.jit通过编译Python函数,JAX能获得显著的性能提升。这方面的主要工具是 jax.jit 变换。可以将 jit(即时编译)看作一种方法,它能够将操作JAX数组的标准Python函数转换为针对您硬件(CPU、GPU或TPU)高度优化、融合的操作序列。
jax.jit应用 jit 变换主要有两种方式:
@jax.jit。jit,并将您的函数作为参数传递。这会返回一个该函数的新编译版本。我们来看一个简单的例子。假设我们有一个函数,它执行一些数值运算:
import jax
import jax.numpy as jnp
import time
# 一个包含一些数值计算的函数
def complex_computation(x, weight, bias):
y = jnp.dot(x, weight) + bias
z = jnp.tanh(y)
return jnp.mean(z)
# 生成一些随机数据
key = jax.random.PRNGKey(0)
x_data = jax.random.normal(key, (1000, 500))
weight_data = jax.random.normal(key, (500, 200))
bias_data = jax.random.normal(key, (200,))
# --- 方法一:将 jit 作为装饰器使用 ---
@jax.jit
def compiled_computation_decorator(x, weight, bias):
y = jnp.dot(x, weight) + bias
z = jnp.tanh(y)
return jnp.mean(z)
# --- 方法二:将 jit 作为函数使用 ---
compiled_computation_functional = jax.jit(complex_computation)
# --- 测量执行时间 ---
# 测量原始 Python 函数的时间
# 运行一次以避免与 JAX 无关的初始开销
_ = complex_computation(x_data, weight_data, bias_data).block_until_ready()
start_time = time.time()
result_original = complex_computation(x_data, weight_data, bias_data).block_until_ready()
end_time = time.time()
print(f"原始函数时间: {end_time - start_time:.6f} seconds")
# 测量 JIT 编译函数的时间(装饰器版本)
# 首次调用包含编译时间
start_time_compile = time.time()
result_compiled_decorator = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_compile = time.time()
print(f"编译函数(装饰器)首次调用(含编译): {end_time_compile - start_time_compile:.6f} seconds")
# 第二次调用使用缓存的编译代码
start_time_cached = time.time()
result_compiled_decorator_cached = compiled_computation_decorator(x_data, weight_data, bias_data).block_until_ready()
end_time_cached = time.time()
print(f"编译函数(装饰器)第二次调用(已缓存): {end_time_cached - start_time_cached:.6f} seconds")
# 验证结果是否相同(在浮点容差范围内)
print(f"结果匹配: {jnp.allclose(result_original, result_compiled_decorator)}")
# 测量函数式版本的时间(首次编译后应相似)
_ = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready() # 编译
start_time_func = time.time()
result_compiled_functional = compiled_computation_functional(x_data, weight_data, bias_data).block_until_ready()
end_time_func = time.time()
print(f"编译函数(函数式)后续调用: {end_time_func - start_time_func:.6f} seconds")
关于计时的一个重要提示: 请注意,在每个我们要计时的函数调用后都使用了 .block_until_ready()。JAX 默认使用异步调度,这意味着操作会被排队,但可能不会立即完成。block_until_ready() 确保计算在记录结束时间之前完成,从而提供准确的测量结果。
您会观察到以下模式:
jit 编译函数的首次调用(compiled_computation_decorator 或 compiled_computation_functional)通常比原始函数 慢。这是因为 JAX 需要执行一个重要的步骤,称为追踪(我们将在下一节讨论),然后使用 XLA(加速线性代数编译器)编译追踪到的操作。装饰器(@jax.jit)和函数式(jax.jit(fn))方式的区别主要是风格上的。在定义函数时,装饰器因其可读性而常被偏好。当您想编译从其他地方获得的函数时(例如,一个库函数,尽管许多 JAX 库函数在适当情况下已在内部进行了 JIT 编译),函数式形式很有用。
简而言之,jax.jit 提供了一种直接的方式来请求编译您对性能有要求的数值函数。通过了解何时以及如何应用它,您可以让 JAX 代码大幅提速。下一节将说明实现这一点的追踪机制。
这部分内容有帮助吗?
jax.jit 的用法、行为和底层机制,包括追踪和异步执行。jax.jit 在内的核心概念,用于性能优化、追踪以及与硬件加速器的交互。© 2026 ApX Machine Learning用心打造