趋近智
JAX和XLA优化计算的方式值得了解。然而,性能并不仅仅与执行的操作相关;它还取决于数据从内存中访问的效率。多维数组数据在线性内存中的物理排列方式,即内存布局,会显著影响性能,特别是在GPU和TPU等拥有专用内存访问硬件的加速器上。
考虑一个简单的二维数组(一个矩阵):
该数组包含6个元素。在计算机内存中,内存本质上是地址的线性序列,这些元素必须按顺序排列。为此,有两种常见的约定方式:
1, 2, 3, 4, 5, 6。第一行的元素是连续的,接着是第二行的元素。这是NumPy使用的默认布局,JAX通常也采用此方式。1, 4, 2, 5, 3, 6。第一列的元素是连续的,接着是第二列的元素。2x3矩阵采用行主序和列主序布局的线性内存表示。地址表示内存中的排列序列。
对于更高维的数组,该原理同样适用。在行主序中,当您遍历连续内存位置时,最后的索引变化最快。在列主序中,最前的索引变化最快。
内存布局对性能的影响源于硬件获取数据的方式:
A[i, j]、A[i, j+1]、A[i, j+2] 等)访问元素通常会实现合并访问,这是非常理想的。沿一列访问元素(A[i, j]、A[i+1, j] 等)通常会导致步进式、低效的访问。默认情况下,JAX为其数组使用行主序布局,这与NumPy一致。然而,JAX在其编译代码中不会直接操作这些数组。在JIT编译期间,XLA会接管处理。
XLA执行复杂的优化,包括布局分配。根据操作序列和目标硬件(CPU、GPU、TPU),XLA可能会内部改变中间数组的布局以提高性能。例如,XLA可能会将数组转置为在GPU或TPU上进行后续矩阵乘法时更高效的布局,即使原始输入是行主序。
这意味着您通常无法(也通常不需要)明确控制JIT编译的JAX函数中每个中间数组的精确内存布局。XLA会做出这些决定。
虽然XLA自动处理大部分布局优化,但对于追求最高性能的高级用户来说,了解这些情况仍然很重要:
jnp.einsum 或 jnp.matmul 等操作中某些数组形状和收缩维度可能更高效地映射到硬件。总而言之,内存布局是加速器性能的一个重要因素。尽管JAX和XLA在许多情况下通过执行自动布局优化来抽象化直接控制,但理解行主序与列主序布局的原理,以及它们与GPU(合并访问)和TPU(分块/MXU效率)等硬件的配合方式,有助于您理解性能分析结果,并在优化核心代码时,设计出更适合硬件的算法。
这部分内容有帮助吗?
ndarray对象的内存布局,包括C-contiguous(行主序)和Fortran-contiguous(列主序)存储顺序。© 2026 ApX Machine Learning用心打造