趋近智
计算两组向量 (vector)之间的两两距离是一项常见的数值计算任务。对此任务应用优化原则,旨在提升其性能。目的不仅是为了提高速度,更是要了解在JAX/XLA生态系统中,为何某些方法表现更优。
假设我们有两组点,包含个点,包含个点,两者均处于维空间。我们希望计算距离矩阵,它代表中第个点与中第个点之间的欧几里得距离。
在JAX中,模拟NumPy实现此功能的一个直接方法可能涉及显式广播。
import jax
import jax.numpy as jnp
import timeit
# 生成一些示例数据
key = jax.random.PRNGKey(0)
N, M, D = 1000, 1500, 64
X = jax.random.normal(key, (N, D))
Y = jax.random.normal(key, (M, D))
# 确保数据在设备上再计时
X = jax.device_put(X)
Y = jax.device_put(Y)
def pairwise_distance_v1(X, Y):
"""使用显式广播计算两两距离。"""
# 为广播扩展维度:X 变为 (N, 1, D),Y 变为 (1, M, D)
diff = X[:, None, :] - Y[None, :, :]
# 计算欧几里得距离的平方
squared_dist = jnp.sum(diff**2, axis=-1)
# 返回平方根
return jnp.sqrt(squared_dist)
# 编译函数
pairwise_distance_v1_jit = jax.jit(pairwise_distance_v1)
# 首次运行以编译
_ = pairwise_distance_v1_jit(X, Y).block_until_ready()
# 基准测试
runs = 10
start_time = timeit.default_timer()
for _ in range(runs):
result_v1 = pairwise_distance_v1_jit(X, Y).block_until_ready()
elapsed_v1 = (timeit.default_timer() - start_time) / runs
print(f"基准JIT版本 (v1) 平均时间: {elapsed_v1:.6f} 秒")
# 示例输出 (会根据硬件而异):
# 基准JIT版本 (v1) 平均时间: 0.002512 秒
借助于JAX的NumPy风格API和JIT编译,此实现已相当高效。JAX会跟踪函数,将其转换为jaxpr,然后XLA将其编译为优化的核。广播操作([:, None, :]和[None, :, :])会创建中间数组,而jnp.sum执行归约。XLA的融合能力很可能会将其中一些操作合并。
我们可以使用矩阵代数表达欧几里得距离的平方:
这提示了一种涉及点积和平方和的替代计算方法。
def pairwise_distance_v2(X, Y):
"""使用矩阵代数恒等式计算两两距离。"""
# 计算X和Y中每个向量的平方范数
x_sq_norms = jnp.sum(X**2, axis=1) # 形状 (N,)
y_sq_norms = jnp.sum(Y**2, axis=1) # 形状 (M,)
# 计算所有向量对之间的点积
# X @ Y.T 得到一个 (N, M) 矩阵,其中元素 (i, j) 是 dot(X[i], Y[j])
dot_products = jnp.dot(X, Y.T)
# 使用恒等式计算平方距离:||x-y||^2 = ||x||^2 - 2*x.y + ||y||^2
# 我们需要重塑范数以进行广播:
# x_sq_norms[:, None] -> (N, 1)
# y_sq_norms[None, :] -> (1, M)
squared_dist = x_sq_norms[:, None] - 2 * dot_products + y_sq_norms[None, :]
# 处理由于浮点数精度问题可能出现的微小负值
squared_dist = jnp.maximum(0.0, squared_dist)
return jnp.sqrt(squared_dist)
# 编译函数
pairwise_distance_v2_jit = jax.jit(pairwise_distance_v2)
# 首次运行以编译
_ = pairwise_distance_v2_jit(X, Y).block_until_ready()
# 基准测试
start_time = timeit.default_timer()
for _ in range(runs):
result_v2 = pairwise_distance_v2_jit(X, Y).block_until_ready()
elapsed_v2 = (timeit.default_timer() - start_time) / runs
print(f"代数JIT版本 (v2) 平均时间: {elapsed_v2:.6f} 秒")
# 示例输出 (会根据硬件而异):
# 代数JIT版本 (v2) 平均时间: 0.001855 秒
为什么v2可能更快,尤其是在加速器上?
pairwise_distance_v1明确创建了一个大型中间数组diff,其形状为(N, M, D)。对于我们的示例大小 (1000, 1500, 64),这相当于1000 * 1500 * 64 = 96,000,000个元素。这会占用大量的内存带宽,这通常是GPU/TPU上的一个瓶颈。pairwise_distance_v2非常依赖于jnp.dot(X, Y.T)。矩阵乘法是一个基本操作,存在高度优化的核(例如NVIDIA GPU上的cuBLAS或特定的TPU核)。XLA可以有效地利用这些核。其他操作(平方和、广播加法/减法)通常是逐元素的,并且XLA通常可以与矩阵乘法或彼此之间进行有效融合。v1中的操作,但与v2的结构相比,大型中间张量可能会限制融合的程度或效率,v2将问题分解为与加速器硬件能力良好对应的操作(矩阵乘法、逐元素操作)。让我们通过代表性的计时数据来可视化潜在的性能差异:
两种JIT编译的两两距离实现的平均执行时间比较。值越低越好。
“在实际情况下,我们如何实现这种优化?”
v1) 开始,并应用 jax.jit。block_until_ready() 和计时器(笔记本中的 timeit 或 %timeit)获取可靠的性能数据。多次运行以平均消除噪音。jax.profiler.start_trace() / stop_trace():捕获可在TensorBoard中查看的执行轨迹。这有助于可视化操作持续时间,并找出计算中哪些部分耗时最多。在v1中,您可能会看到大量时间花在广播或归约操作上。jax.make_jaxpr 检查中间表示。虽然通常很冗长,但它可以显示JAX在XLA优化之前如何看待计算,可能突出显示大型中间结构。diff数组是否开销很大?”或者“我可以使用更直接的矩阵乘法吗?”。根据这些想法重构代码 (v2)。这种迭代的实现、基准测试、性能分析和重构过程对于有效优化JAX代码至关重要。了解操作如何映射到硬件能力以及XLA如何执行融合,可以帮助您编写JAX能够编译成高效例程的代码。请记住,“最佳”实现有时可能取决于具体的维度 (N, M, D) 和目标硬件 (CPU, GPU, TPU)。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•