趋近智
JAX 使用 XLA 将 Python 函数编译为针对加速器优化的代码。这一优化过程的重要组成部分是算子合并,这是一种 XLA 将 JAX 代码中多个不同的操作合并为一个在加速器上执行的更大计算核的技术。算子合并的工作原理、其性能优势以及观察其效果的方法将得到考察。
理解合并(操作)不仅是为了体会 JAX 速度背后的“奥秘”,而且对于解读性能分析结果以及偶尔构建代码以避免无意中阻止这些优化也同样重要。
算子合并的核心是将处理数据元素或具有生产者-消费者关系的顺序操作合并为一个复合操作。考虑一个简单的操作序列:
import jax
import jax.numpy as jnp
def simple_computation(x, y):
a = jnp.log(x)
b = a + y
c = jnp.exp(b)
return c
# JIT 编译函数
compiled_computation = jax.jit(simple_computation)
# 示例数据
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (1000, 1000))
y = jax.random.uniform(key, (1000, 1000))
# 执行
result = compiled_computation(x, y).block_until_ready()
如果没有合并,在 GPU 上执行 simple_computation 可能涉及三个独立的步骤(核启动):
x,计算 log(x),将结果 a 写回内存。a 和 y,计算 a + y,将结果 b 写回内存。b,计算 exp(b),将最终结果 c 写回内存。每个步骤都涉及从加速器的主内存(例如 GPU HBM)读取输入,执行计算,并将输出写回主内存。这种内存传输通常是主要的性能瓶颈。
XLA 的合并优化会分析计算图 (jaxpr),并识别出中间结果 (a 和 b) 仅被下一个操作立即使用。之后,它可以将这些操作合并为一个单独的核。
simple_computation操作在合并前的表示。每个椭圆代表一个潜在的独立核启动,其中包含对其输入/输出的内存读/写。
经过合并后,过程变得更加高效:
x 和 y。exp(log(x) + y)。中间结果 log(x) 和 log(x) + y 保留在加速器核心内快速的片上内存(寄存器或缓存)中。c 一次性写回内存。合并后的表示。元素级操作被合并为一个单独的核,从而最大限度地减少了与主内存之间的数据移动。
算子合并的主要好处是:
你通常不会直接在 JAX 中与合并(操作)进行交互;它是在 jax.jit 编译过程中由 XLA 执行的自动优化。但是,你可以观察到它的影响:
@jit 下运行速度明显快于它们各自执行时间的总和(如果在没有 @jit 的情况下运行,强制中间结果具化为完整的 NumPy 数组),那么合并很可能是主要原因。尽管合并是自动的,但理解它有助于编写 XLA 可以有效优化的 JAX 代码:
jax.numpy 操作保持在一起。XLA 在合并这些操作方面特别有效。合并是 JAX 在加速器上性能的根本。通过减少内存传输和核启动开销,它使以高级 NumPy 类似 API 表示的计算能够在硬件上高效执行,其速度通常接近手动调整的低级代码。认识到它的效果有助于理解性能分析,并体会当你使用 jax.jit 时在幕后发生的优化。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造