Now, let's put the theory into practice. We'll take a common numerical task, calculating pairwise distances between two sets of vectors, and apply the optimization principles discussed in this chapter. Our goal is not just to make it faster, but to understand why certain approaches perform better within the JAX/XLA ecosystem.
Suppose we have two sets of points, X with N points and Y with M points, both in a D-dimensional space. We want to compute a distance matrix Dij representing the Euclidean distance between the i-th point in X and the j-th point in Y.
$$X \in \mathbb{R}^{N \times D}, \quad Y \in \mathbb{R}^{M \times D}, \quad D \in \mathbb{R}^{N \times M}$$
$$D_{ij} = \sqrt{\sum_{k=1}^{D} (X_{ik} - Y_{jk})^2}$$
A straightforward way to implement this in JAX, mimicking NumPy, might involve explicit broadcasting.
import jax
import jax.numpy as jnp
import timeit
# Generate some sample data
key = jax.random.PRNGKey(0)
N, M, D = 1000, 1500, 64
# Split the key so X and Y are generated from independent random streams
key_x, key_y = jax.random.split(key)
X = jax.random.normal(key_x, (N, D))
Y = jax.random.normal(key_y, (M, D))
# Ensure data is on the device before timing
X = jax.device_put(X)
Y = jax.device_put(Y)
def pairwise_distance_v1(X, Y):
"""Computes pairwise distances using explicit broadcasting."""
# Expand dimensions for broadcasting: X becomes (N, 1, D), Y becomes (1, M, D)
diff = X[:, None, :] - Y[None, :, :]
# Calculate squared Euclidean distance
squared_dist = jnp.sum(diff**2, axis=-1)
# Return the square root
return jnp.sqrt(squared_dist)
# Compile the function
pairwise_distance_v1_jit = jax.jit(pairwise_distance_v1)
# First run to compile
_ = pairwise_distance_v1_jit(X, Y).block_until_ready()
# Benchmark
runs = 10
start_time = timeit.default_timer()
for _ in range(runs):
    result_v1 = pairwise_distance_v1_jit(X, Y).block_until_ready()
elapsed_v1 = (timeit.default_timer() - start_time) / runs
print(f"Baseline JIT version (v1) average time: {elapsed_v1:.6f} seconds")
# Example Output (will vary based on hardware):
# Baseline JIT version (v1) average time: 0.002512 seconds
This implementation is already quite efficient due to JAX's NumPy-like API and JIT compilation. JAX traces the function, converts it to a jaxpr, and XLA compiles it into optimized kernels. The broadcasting operations (X[:, None, :] and Y[None, :, :]) create an intermediate array of differences, and jnp.sum performs the reduction. XLA's fusion capabilities will likely combine some of these operations.
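To see the size of that intermediate without running the full computation, we can ask JAX for the shape it would produce. This is a small sketch using jax.eval_shape on the broadcasted difference alone; the lambda here is purely illustrative and is not part of the distance function.
# Abstractly evaluate just the broadcasted difference to see the intermediate's shape.
# No actual computation or device memory allocation happens here.
diff_spec = jax.eval_shape(lambda a, b: a[:, None, :] - b[None, :, :], X, Y)
print(diff_spec.shape, diff_spec.dtype)  # (1000, 1500, 64) float32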
We can express the squared Euclidean distance using matrix algebra:
$$\|x_i - y_j\|^2 = \|x_i\|^2 - 2\,x_i^T y_j + \|y_j\|^2$$
This suggests an alternative computation involving dot products and sums of squares.
def pairwise_distance_v2(X, Y):
"""Computes pairwise distances using matrix algebra identity."""
# Calculate squared norms for each vector in X and Y
x_sq_norms = jnp.sum(X**2, axis=1) # Shape (N,)
y_sq_norms = jnp.sum(Y**2, axis=1) # Shape (M,)
# Calculate the dot products between all pairs of vectors
# X @ Y.T results in a (N, M) matrix where element (i, j) is dot(X[i], Y[j])
dot_products = jnp.dot(X, Y.T)
# Compute squared distances using the identity: ||x-y||^2 = ||x||^2 - 2*x.y + ||y||^2
# We need to reshape norms for broadcasting:
# x_sq_norms[:, None] -> (N, 1)
# y_sq_norms[None, :] -> (1, M)
squared_dist = x_sq_norms[:, None] - 2 * dot_products + y_sq_norms[None, :]
# Handle potential small negative values due to floating point inaccuracies
squared_dist = jnp.maximum(0.0, squared_dist)
return jnp.sqrt(squared_dist)
# Compile the function
pairwise_distance_v2_jit = jax.jit(pairwise_distance_v2)
# First run to compile
_ = pairwise_distance_v2_jit(X, Y).block_until_ready()
# Benchmark
start_time = timeit.default_timer()
for _ in range(runs):
    result_v2 = pairwise_distance_v2_jit(X, Y).block_until_ready()
elapsed_v2 = (timeit.default_timer() - start_time) / runs
print(f"Algebraic JIT version (v2) average time: {elapsed_v2:.6f} seconds")
# Example Output (will vary based on hardware):
# Algebraic JIT version (v2) average time: 0.001855 seconds
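Before reading too much into the timings, it is worth confirming that the two implementations agree. A quick sanity check, reusing result_v1 and result_v2 from the benchmark loops above; the tolerance is a loose, illustrative choice, since the algebraic form accumulates slightly different rounding error.
# Verify both implementations produce numerically consistent distances.
# A loose tolerance is used because v2's rearrangement rounds differently.
print(jnp.allclose(result_v1, result_v2, atol=1e-4))  # Expect: True
print(jnp.max(jnp.abs(result_v1 - result_v2)))        # Largest absolute discrepancy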
Why might v2 be faster, especially on accelerators?
- pairwise_distance_v1 explicitly creates a large intermediate array diff of shape (N, M, D). For our example sizes (1000, 1500, 64), this is 1000 * 1500 * 64 = 96,000,000 elements, about 384 MB in float32. Materializing it can consume significant memory bandwidth, which is often the bottleneck on GPUs/TPUs.
- pairwise_distance_v2 relies heavily on jnp.dot(X, Y.T). Matrix multiplication is a fundamental operation for which highly optimized kernels exist (such as cuBLAS on NVIDIA GPUs or dedicated TPU kernels), and XLA can target these kernels effectively. The remaining operations (summing squares, broadcast additions and subtractions) are elementwise and can often be fused by XLA with the matrix multiplication or with each other.
- In v1, the large intermediate tensor might limit the extent or efficiency of fusion, whereas v2 breaks the problem into operations (a matrix multiply plus elementwise work) that map well to accelerator hardware capabilities.
Let's visualize the potential performance difference (using hypothetical but representative timings):
Comparison of average execution times for the two JIT-compiled pairwise distance implementations. Lower is better.
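To produce such a comparison chart from your own measurements, here is a minimal sketch using matplotlib (assumed to be installed), reusing the elapsed_v1 and elapsed_v2 values from the benchmarks above:
import matplotlib.pyplot as plt
# Bar chart of the measured average times; your numbers will differ by hardware.
labels = ["v1: broadcasting", "v2: algebraic"]
times = [elapsed_v1, elapsed_v2]
plt.figure(figsize=(5, 3))
plt.bar(labels, times)
plt.ylabel("Average time per call (s)")
plt.title("Pairwise distance implementations (JIT-compiled)")
plt.tight_layout()
plt.show()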
How would we arrive at this optimization in a real-world scenario?
1. Implement a straightforward, correct version first (v1) and apply jax.jit.
2. Benchmark it: use .block_until_ready() together with a timer (timeit, or %timeit in notebooks) to get reliable performance numbers, and run multiple times to average out noise.
3. Profile it with jax.profiler.start_trace() / stop_trace() to capture execution traces viewable in TensorBoard. This helps visualize operation durations and identify which parts of the computation take the most time; you might see significant time spent in broadcasting or reduction operations in v1.
4. Inspect the computation with jax.make_jaxpr to examine the intermediate representation. While often verbose, it can reveal how JAX sees the computation before XLA optimization, potentially highlighting large intermediate structures. (A short sketch of this and the previous step appears below.)
5. Refactor: ask questions such as "Is creating the large diff array costly?" or "Can I use a more direct matrix multiplication?", then rewrite the code (v2) based on these ideas.
This iterative process of implementing, benchmarking, profiling, and refactoring is essential for optimizing JAX code effectively. Understanding how operations map to hardware capabilities and how XLA performs fusion allows you to write code that JAX can compile into highly efficient routines. Remember that the "best" implementation can sometimes depend on the specific dimensions (N, M, D) and the target hardware (CPU, GPU, TPU).
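As a minimal sketch of the jaxpr inspection and trace capture mentioned in steps 3 and 4 (the log directory /tmp/jax_trace is an arbitrary choice):
# Inspect the intermediate representation JAX builds before XLA optimization.
print(jax.make_jaxpr(pairwise_distance_v1)(X, Y))
print(jax.make_jaxpr(pairwise_distance_v2)(X, Y))
# Capture a trace of a few executions, then view it in TensorBoard
# (tensorboard --logdir /tmp/jax_trace) under the Profile tab.
jax.profiler.start_trace("/tmp/jax_trace")
for _ in range(5):
    pairwise_distance_v1_jit(X, Y).block_until_ready()
    pairwise_distance_v2_jit(X, Y).block_until_ready()
jax.profiler.stop_trace()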