趋近智
JAX 提供了强大的抽象方式,例如 jit,能够大幅加速你的 Python 和 NumPy 代码,尤其是在硬件加速器上。然而,要达到最佳性能,往往需要审视其内部机制。仅仅应用 @jit 并不能保证达到最高速度。
本章主要介绍诊断性能瓶颈的方法,以及优化 JAX 程序以适用于 GPU 和 TPU。我们将介绍如何分析 JAX 执行情况,理解跟踪过程中生成的中间 jaxpr 表示,了解 XLA 编译器在优化中的职能,考虑内存布局的影响,尽量减少代价高昂的重新编译,识别算子融合,并使用 JAX 的异步调度机制正确地进行代码基准测试。完成本章后,你将能够有条理地分析并提升你的 JAX 计算的执行速度。
2.1 在 CPU、GPU 和 TPU 上分析 JAX 代码
2.2 理解 JAX 计算图 (jaxpr)
2.3 XLA 编译的作用
2.4 内存布局及其对性能的影响
2.5 避免重复编译
2.6 算子合并与操作优化
2.7 异步调度
2.8 实践:优化数值计算