JAX和XLA优化计算的方式值得了解。然而,性能并不仅仅与执行的操作相关;它还取决于数据从内存中访问的效率。多维数组数据在线性内存中的物理排列方式,即内存布局,会显著影响性能,特别是在GPU和TPU等拥有专用内存访问硬件的加速器上。行主序与列主序布局考虑一个简单的二维数组(一个矩阵): $$ A = \begin{pmatrix} 1 & 2 & 3 \ 4 & 5 & 6 \end{pmatrix} $$ 该数组包含6个元素。在计算机内存中,内存本质上是地址的线性序列,这些元素必须按顺序排列。为此,有两种常见的约定方式:行主序布局(C语言风格): 元素按行存储。对于矩阵 $A$,内存序列将是 1, 2, 3, 4, 5, 6。第一行的元素是连续的,接着是第二行的元素。这是NumPy使用的默认布局,JAX通常也采用此方式。列主序布局(Fortran语言风格): 元素按列存储。对于矩阵 $A$,内存序列将是 1, 4, 2, 5, 3, 6。第一列的元素是连续的,接着是第二列的元素。digraph MemoryLayout { rankdir=LR; node [shape=record, fontname="Arial"]; edge [style=invis]; // 用于间距的不可见边 subgraph cluster_row { label = "行主序布局\nA = [[1, 2, 3], [4, 5, 6]]"; bgcolor="#e9ecef"; node [style=filled, fillcolor="#a5d8ff"]; row_mem [label="<f0> 1 |<f1> 2 |<f2> 3 |<f3> 4 |<f4> 5 |<f5> 6"]; row_addr [label="地址: 0, 1, 2, 3, 4, 5", shape=plaintext]; row_mem -> row_addr [style=invis]; // 强制地址在下方 } subgraph cluster_col { label = "列主序布局\nA = [[1, 2, 3], [4, 5, 6]]"; bgcolor="#e9ecef"; node [style=filled, fillcolor="#ffc9c9"]; col_mem [label="<f0> 1 |<f1> 4 |<f2> 2 |<f3> 5 |<f4> 3 |<f5> 6"]; col_addr [label="地址: 0, 1, 2, 3, 4, 5", shape=plaintext]; col_mem -> col_addr [style=invis]; // 强制地址在下方 } }2x3矩阵采用行主序和列主序布局的线性内存表示。地址表示内存中的排列序列。对于更高维的数组,该原理同样适用。在行主序中,当您遍历连续内存位置时,最后的索引变化最快。在列主序中,最前的索引变化最快。布局为何对加速器很重要内存布局对性能的影响源于硬件获取数据的方式:CPU缓存: CPU将数据从主内存加载到更快的缓存中,以称为缓存行的连续数据块形式。按顺序访问数据(空间局部性)可提高缓存利用率。行主序布局天然有利于按行迭代的算法,而列主序则有利于按列迭代。GPU内存访问合并: GPU通过让线程组(warp)同时访问内存来实现高内存带宽。如果一个warp中的线程访问全局内存中的连续位置,这些访问通常可以合并为一个单一的、宽的内存事务。但是,如果访问是分散的(步进式的),则需要多个事务,从而大幅降低有效带宽。在行主序布局中,一个warp中不同线程沿最后一个维度(例如 A[i, j]、A[i, j+1]、A[i, j+2] 等)访问元素通常会实现合并访问,这是非常理想的。沿一列访问元素(A[i, j]、A[i+1, j] 等)通常会导致步进式、低效的访问。TPU内存架构: TPU拥有连接到专用矩阵乘法单元(MXU)的高带宽内存。当数据以特定的分块格式或布局呈现时,这些单元的运行效率最高。XLA的TPU后端处理布局转换以针对MXU进行优化,但初始布局会影响这些转换的额外开销。如果输入数据布局已经与TPU的偏好格式良好匹配,操作可能会明显更快。JAX和XLA中的布局默认情况下,JAX为其数组使用行主序布局,这与NumPy一致。然而,JAX在其编译代码中不会直接操作这些数组。在JIT编译期间,XLA会接管处理。XLA执行复杂的优化,包括布局分配。根据操作序列和目标硬件(CPU、GPU、TPU),XLA可能会内部改变中间数组的布局以提高性能。例如,XLA可能会将数组转置为在GPU或TPU上进行后续矩阵乘法时更高效的布局,即使原始输入是行主序。这意味着您通常无法(也通常不需要)明确控制JIT编译的JAX函数中每个中间数组的精确内存布局。XLA会做出这些决定。实际影响和考量虽然XLA自动处理大部分布局优化,但对于追求最高性能的高级用户来说,了解这些情况仍然很重要:识别内存瓶颈: 使用性能分析工具(前面已讨论)来判断您的应用程序是否受内存限制。涉及大数组的操作变慢可能表示内存访问模式效率低下,这可能与布局有关。理解硬件特性: 了解您的主要目标硬件的偏好访问模式。对于GPU,请考虑沿最后一个维度(在行主序中)的访问合并。对于TPU,请注意 jnp.einsum 或 jnp.matmul 等操作中某些数组形状和收缩维度可能更高效地映射到硬件。算法与数据结构设计: 尽管直接布局控制有限,但有时您可以设计算法或组织数据,以促进有利的访问模式。如果您在GPU上使用行主序数据时频繁需要按列访问,并且性能分析显示这是瓶颈,那么可以考虑是否提前转置一次数据能够整体有利,尽管有初始转置的成本。大数据传输: 在框架之间传输大数组(例如从NumPy到JAX)时,请注意布局。虽然JAX处理转换,但请理解NumPy可以表示列主序数组。确保数据在传递给JAX函数之前以标准行主序格式开始,可以避免潜在的初始转换开销。自定义操作: 如果您定义自定义原语或集成外部代码(第5章会讲到),您将直接负责为目标后端正确且高效地处理内存布局。总而言之,内存布局是加速器性能的一个重要因素。尽管JAX和XLA在许多情况下通过执行自动布局优化来抽象化直接控制,但理解行主序与列主序布局的原理,以及它们与GPU(合并访问)和TPU(分块/MXU效率)等硬件的配合方式,有助于您理解性能分析结果,并在优化核心代码时,设计出更适合硬件的算法。