计算两组向量之间的两两距离是一项常见的数值计算任务。对此任务应用优化原则,旨在提升其性能。目的不仅是为了提高速度,更是要了解在JAX/XLA生态系统中,为何某些方法表现更优。假设我们有两组点,$X$包含$N$个点,$Y$包含$M$个点,两者均处于$D$维空间。我们希望计算距离矩阵$D_{ij}$,它代表$X$中第$i$个点与$Y$中第$j$个点之间的欧几里得距离。$X \in \mathbb{R}^{N \times D}$ $Y \in \mathbb{R}^{M \times D}$ $D \in \mathbb{R}^{N \times M}$ $$ D_{ij} = \sqrt{\sum_{k=1}^{D} (X_{ik} - Y_{jk})^2} $$基准实现在JAX中,模拟NumPy实现此功能的一个直接方法可能涉及显式广播。import jax import jax.numpy as jnp import timeit # 生成一些示例数据 key = jax.random.PRNGKey(0) N, M, D = 1000, 1500, 64 X = jax.random.normal(key, (N, D)) Y = jax.random.normal(key, (M, D)) # 确保数据在设备上再计时 X = jax.device_put(X) Y = jax.device_put(Y) def pairwise_distance_v1(X, Y): """使用显式广播计算两两距离。""" # 为广播扩展维度:X 变为 (N, 1, D),Y 变为 (1, M, D) diff = X[:, None, :] - Y[None, :, :] # 计算欧几里得距离的平方 squared_dist = jnp.sum(diff**2, axis=-1) # 返回平方根 return jnp.sqrt(squared_dist) # 编译函数 pairwise_distance_v1_jit = jax.jit(pairwise_distance_v1) # 首次运行以编译 _ = pairwise_distance_v1_jit(X, Y).block_until_ready() # 基准测试 runs = 10 start_time = timeit.default_timer() for _ in range(runs): result_v1 = pairwise_distance_v1_jit(X, Y).block_until_ready() elapsed_v1 = (timeit.default_timer() - start_time) / runs print(f"基准JIT版本 (v1) 平均时间: {elapsed_v1:.6f} 秒") # 示例输出 (会根据硬件而异): # 基准JIT版本 (v1) 平均时间: 0.002512 秒借助于JAX的NumPy风格API和JIT编译,此实现已相当高效。JAX会跟踪函数,将其转换为jaxpr,然后XLA将其编译为优化的核。广播操作([:, None, :]和[None, :, :])会创建中间数组,而jnp.sum执行归约。XLA的融合能力很可能会将其中一些操作合并。另一种实现:应用矩阵代数我们可以使用矩阵代数表达欧几里得距离的平方: $||x_i - y_j||^2 = ||x_i||^2 - 2 x_i^T y_j + ||y_j||^2$这提示了一种涉及点积和平方和的替代计算方法。def pairwise_distance_v2(X, Y): """使用矩阵代数恒等式计算两两距离。""" # 计算X和Y中每个向量的平方范数 x_sq_norms = jnp.sum(X**2, axis=1) # 形状 (N,) y_sq_norms = jnp.sum(Y**2, axis=1) # 形状 (M,) # 计算所有向量对之间的点积 # X @ Y.T 得到一个 (N, M) 矩阵,其中元素 (i, j) 是 dot(X[i], Y[j]) dot_products = jnp.dot(X, Y.T) # 使用恒等式计算平方距离:||x-y||^2 = ||x||^2 - 2*x.y + ||y||^2 # 我们需要重塑范数以进行广播: # x_sq_norms[:, None] -> (N, 1) # y_sq_norms[None, :] -> (1, M) squared_dist = x_sq_norms[:, None] - 2 * dot_products + y_sq_norms[None, :] # 处理由于浮点数精度问题可能出现的微小负值 squared_dist = jnp.maximum(0.0, squared_dist) return jnp.sqrt(squared_dist) # 编译函数 pairwise_distance_v2_jit = jax.jit(pairwise_distance_v2) # 首次运行以编译 _ = pairwise_distance_v2_jit(X, Y).block_until_ready() # 基准测试 start_time = timeit.default_timer() for _ in range(runs): result_v2 = pairwise_distance_v2_jit(X, Y).block_until_ready() elapsed_v2 = (timeit.default_timer() - start_time) / runs print(f"代数JIT版本 (v2) 平均时间: {elapsed_v2:.6f} 秒") # 示例输出 (会根据硬件而异): # 代数JIT版本 (v2) 平均时间: 0.001855 秒分析与比较为什么v2可能更快,尤其是在加速器上?中间分配 (v1): 第一个版本pairwise_distance_v1明确创建了一个大型中间数组diff,其形状为(N, M, D)。对于我们的示例大小 (1000, 1500, 64),这相当于1000 * 1500 * 64 = 96,000,000个元素。这会占用大量的内存带宽,这通常是GPU/TPU上的一个瓶颈。优化核 (v2): 第二个版本pairwise_distance_v2非常依赖于jnp.dot(X, Y.T)。矩阵乘法是一个基本操作,存在高度优化的核(例如NVIDIA GPU上的cuBLAS或特定的TPU核)。XLA可以有效地利用这些核。其他操作(平方和、广播加法/减法)通常是逐元素的,并且XLA通常可以与矩阵乘法或彼此之间进行有效融合。XLA 融合: 尽管XLA可以融合v1中的操作,但与v2的结构相比,大型中间张量可能会限制融合的程度或效率,v2将问题分解为与加速器硬件能力良好对应的操作(矩阵乘法、逐元素操作)。让我们通过代表性的计时数据来可视化潜在的性能差异:{"data": [{"x": ["基准 (v1)", "代数 (v2)"], "y": [0.002512, 0.001855], "type": "bar", "marker": {"color": ["#fa5252", "#40c057"]}}], "layout": {"title": "两两距离计算时间 (JIT 编译)", "yaxis": {"title": "平均执行时间 (秒)"}, "xaxis": {"title": "实现版本"}, "template": "plotly_white", "width": 600, "height": 400}}两种JIT编译的两两距离实现的平均执行时间比较。值越低越好。优化过程“在实际情况下,我们如何实现这种优化?”基准与JIT: 从一个清晰、易读的实现 (v1) 开始,并应用 jax.jit。基准测试: 使用 block_until_ready() 和计时器(笔记本中的 timeit 或 %timeit)获取可靠的性能数据。多次运行以平均消除噪音。性能分析(找出瓶颈): 如果性能不令人满意,请使用JAX的性能分析工具:jax.profiler.start_trace() / stop_trace():捕获可在TensorBoard中查看的执行轨迹。这有助于可视化操作持续时间,并找出计算中哪些部分耗时最多。在v1中,您可能会看到大量时间花在广播或归约操作上。特定于设备的性能分析器(例如NVIDIA Nsight Systems):用于进行更深入的硬件级分析。检查 Jaxpr(可选): 使用 jax.make_jaxpr 检查中间表示。虽然通常很冗长,但它可以显示JAX在XLA优化之前如何看待计算,可能突出显示大型中间结构。假设与重构: 根据性能分析或对XLA/硬件的理解,形成假设。“大型中间diff数组是否开销很大?”或者“我可以使用更直接的矩阵乘法吗?”。根据这些想法重构代码 (v2)。重新基准测试: 比较新版本与基准版本的性能。这种迭代的实现、基准测试、性能分析和重构过程对于有效优化JAX代码至关重要。了解操作如何映射到硬件能力以及XLA如何执行融合,可以帮助您编写JAX能够编译成高效例程的代码。请记住,“最佳”实现有时可能取决于具体的维度 (N, M, D) 和目标硬件 (CPU, GPU, TPU)。